From a13a5ed67badf6fd5684a9752e1fd44c2e7f5134 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 3 Oct 2023 21:18:47 +0900 Subject: [PATCH 01/19] Mark 3.12 --- .github/workflows/ci.yaml | 2 +- .github/workflows/cron-ci.yaml | 2 +- DEVELOPMENT.md | 2 +- README.md | 2 +- vm/src/version.rs | 4 ++-- whats_left.py | 4 ++-- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b4e1eb1932..e616f6d10c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -105,7 +105,7 @@ env: test_weakref test_yield_from # Python version targeted by the CI. - PYTHON_VERSION: "3.11.4" + PYTHON_VERSION: "3.12.0" jobs: rust_tests: diff --git a/.github/workflows/cron-ci.yaml b/.github/workflows/cron-ci.yaml index ee90aac4a9..9176f232c7 100644 --- a/.github/workflows/cron-ci.yaml +++ b/.github/workflows/cron-ci.yaml @@ -7,7 +7,7 @@ name: Periodic checks/tasks env: CARGO_ARGS: --no-default-features --features stdlib,zlib,importlib,encodings,ssl,jit - PYTHON_VERSION: "3.11.4" + PYTHON_VERSION: "3.12.0" jobs: # codecov collects code coverage data from the rust tests, python snippets and python test suite. diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 4ef49abe94..7c79a011ba 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -25,7 +25,7 @@ RustPython requires the following: stable version: `rustup update stable` - If you do not have Rust installed, use [rustup](https://rustup.rs/) to do so. -- CPython version 3.11 or higher +- CPython version 3.12 or higher - CPython can be installed by your operating system's package manager, from the [Python website](https://www.python.org/downloads/), or using a third-party distribution, such as diff --git a/README.md b/README.md index 2c8266cb12..65d57a7ee5 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # [RustPython](https://rustpython.github.io/) -A Python-3 (CPython >= 3.11.0) Interpreter written in Rust :snake: :scream: +A Python-3 (CPython >= 3.12.0) Interpreter written in Rust :snake: :scream: :metal:. [![Build Status](https://github.com/RustPython/RustPython/workflows/CI/badge.svg)](https://github.com/RustPython/RustPython/actions?query=workflow%3ACI) diff --git a/vm/src/version.rs b/vm/src/version.rs index ec23e896b0..9a75f71142 100644 --- a/vm/src/version.rs +++ b/vm/src/version.rs @@ -4,9 +4,9 @@ use chrono::{prelude::DateTime, Local}; use std::time::{Duration, UNIX_EPOCH}; -// = 3.11.0alpha +// = 3.12.0alpha pub const MAJOR: usize = 3; -pub const MINOR: usize = 11; +pub const MINOR: usize = 12; pub const MICRO: usize = 0; pub const RELEASELEVEL: &str = "alpha"; pub const RELEASELEVEL_N: usize = 0xA; diff --git a/whats_left.py b/whats_left.py index 7f3ad80c63..4f087f89af 100755 --- a/whats_left.py +++ b/whats_left.py @@ -35,8 +35,8 @@ implementation = platform.python_implementation() if implementation != "CPython": sys.exit(f"whats_left.py must be run under CPython, got {implementation} instead") -if sys.version_info[:2] < (3, 11): - sys.exit(f"whats_left.py must be run under CPython 3.11 or newer, got {implementation} {sys.version} instead") +if sys.version_info[:2] < (3, 12): + sys.exit(f"whats_left.py must be run under CPython 3.12 or newer, got {implementation} {sys.version} instead") def parse_args(): parser = argparse.ArgumentParser(description="Process some integers.") From 49dddceef6c490093d64c191c8e9196493fc0437 Mon Sep 17 00:00:00 2001 From: CPython developers <> Date: Tue, 3 Oct 2023 22:50:09 +0900 Subject: [PATCH 02/19] Update importlib from Python 3.12.0 --- Lib/importlib/__init__.py | 46 +-- Lib/importlib/_abc.py | 15 - Lib/importlib/_bootstrap.py | 454 ++++++++++++++++++-------- Lib/importlib/_bootstrap_external.py | 249 +++++++------- Lib/importlib/abc.py | 111 +------ Lib/importlib/metadata/__init__.py | 305 ++++++----------- Lib/importlib/metadata/_adapters.py | 21 ++ Lib/importlib/metadata/_meta.py | 28 +- Lib/importlib/resources/_adapters.py | 4 +- Lib/importlib/resources/_common.py | 148 +++++++-- Lib/importlib/resources/_itertools.py | 69 ++-- Lib/importlib/resources/_legacy.py | 3 +- Lib/importlib/resources/abc.py | 26 +- Lib/importlib/resources/readers.py | 50 ++- Lib/importlib/resources/simple.py | 79 ++--- Lib/importlib/util.py | 144 +++----- 16 files changed, 884 insertions(+), 868 deletions(-) diff --git a/Lib/importlib/__init__.py b/Lib/importlib/__init__.py index ce61883288..707c081cb2 100644 --- a/Lib/importlib/__init__.py +++ b/Lib/importlib/__init__.py @@ -70,41 +70,6 @@ def invalidate_caches(): finder.invalidate_caches() -def find_loader(name, path=None): - """Return the loader for the specified module. - - This is a backward-compatible wrapper around find_spec(). - - This function is deprecated in favor of importlib.util.find_spec(). - - """ - warnings.warn('Deprecated since Python 3.4 and slated for removal in ' - 'Python 3.12; use importlib.util.find_spec() instead', - DeprecationWarning, stacklevel=2) - try: - loader = sys.modules[name].__loader__ - if loader is None: - raise ValueError('{}.__loader__ is None'.format(name)) - else: - return loader - except KeyError: - pass - except AttributeError: - raise ValueError('{}.__loader__ is not set'.format(name)) from None - - spec = _bootstrap._find_spec(name, path) - # We won't worry about malformed specs (missing attributes). - if spec is None: - return None - if spec.loader is None: - if spec.submodule_search_locations is None: - raise ImportError('spec for {} missing loader'.format(name), - name=name) - raise ImportError('namespace packages do not have loaders', - name=name) - return spec.loader - - def import_module(name, package=None): """Import a module. @@ -116,9 +81,8 @@ def import_module(name, package=None): level = 0 if name.startswith('.'): if not package: - msg = ("the 'package' argument is required to perform a relative " - "import for {!r}") - raise TypeError(msg.format(name)) + raise TypeError("the 'package' argument is required to perform a " + f"relative import for {name!r}") for character in name: if character != '.': break @@ -144,8 +108,7 @@ def reload(module): raise TypeError("reload() argument must be a module") if sys.modules.get(name) is not module: - msg = "module {} not in sys.modules" - raise ImportError(msg.format(name), name=name) + raise ImportError(f"module {name} not in sys.modules", name=name) if name in _RELOADING: return _RELOADING[name] _RELOADING[name] = module @@ -155,8 +118,7 @@ def reload(module): try: parent = sys.modules[parent_name] except KeyError: - msg = "parent {!r} not in sys.modules" - raise ImportError(msg.format(parent_name), + raise ImportError(f"parent {parent_name!r} not in sys.modules", name=parent_name) from None else: pkgpath = parent.__path__ diff --git a/Lib/importlib/_abc.py b/Lib/importlib/_abc.py index f80348fc7f..693b466112 100644 --- a/Lib/importlib/_abc.py +++ b/Lib/importlib/_abc.py @@ -1,7 +1,6 @@ """Subset of importlib.abc used to reduce importlib.util imports.""" from . import _bootstrap import abc -import warnings class Loader(metaclass=abc.ABCMeta): @@ -38,17 +37,3 @@ def load_module(self, fullname): raise ImportError # Warning implemented in _load_module_shim(). return _bootstrap._load_module_shim(self, fullname) - - def module_repr(self, module): - """Return a module's repr. - - Used by the module type when the method does not raise - NotImplementedError. - - This method is deprecated. - - """ - warnings.warn("importlib.abc.Loader.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - # The exception will cause ModuleType.__repr__ to ignore this method. - raise NotImplementedError diff --git a/Lib/importlib/_bootstrap.py b/Lib/importlib/_bootstrap.py index b1fdad8e6d..093a0b8245 100644 --- a/Lib/importlib/_bootstrap.py +++ b/Lib/importlib/_bootstrap.py @@ -51,17 +51,178 @@ def _new_module(name): # Module-level locking ######################################################## -# A dict mapping module names to weakrefs of _ModuleLock instances -# Dictionary protected by the global import lock +# For a list that can have a weakref to it. +class _List(list): + pass + + +# Copied from weakref.py with some simplifications and modifications unique to +# bootstrapping importlib. Many methods were simply deleting for simplicity, so if they +# are needed in the future they may work if simply copied back in. +class _WeakValueDictionary: + + def __init__(self): + self_weakref = _weakref.ref(self) + + # Inlined to avoid issues with inheriting from _weakref.ref before _weakref is + # set by _setup(). Since there's only one instance of this class, this is + # not expensive. + class KeyedRef(_weakref.ref): + + __slots__ = "key", + + def __new__(type, ob, key): + self = super().__new__(type, ob, type.remove) + self.key = key + return self + + def __init__(self, ob, key): + super().__init__(ob, self.remove) + + @staticmethod + def remove(wr): + nonlocal self_weakref + + self = self_weakref() + if self is not None: + if self._iterating: + self._pending_removals.append(wr.key) + else: + _weakref._remove_dead_weakref(self.data, wr.key) + + self._KeyedRef = KeyedRef + self.clear() + + def clear(self): + self._pending_removals = [] + self._iterating = set() + self.data = {} + + def _commit_removals(self): + pop = self._pending_removals.pop + d = self.data + while True: + try: + key = pop() + except IndexError: + return + _weakref._remove_dead_weakref(d, key) + + def get(self, key, default=None): + if self._pending_removals: + self._commit_removals() + try: + wr = self.data[key] + except KeyError: + return default + else: + if (o := wr()) is None: + return default + else: + return o + + def setdefault(self, key, default=None): + try: + o = self.data[key]() + except KeyError: + o = None + if o is None: + if self._pending_removals: + self._commit_removals() + self.data[key] = self._KeyedRef(default, key) + return default + else: + return o + + +# A dict mapping module names to weakrefs of _ModuleLock instances. +# Dictionary protected by the global import lock. _module_locks = {} -# A dict mapping thread ids to _ModuleLock instances -_blocking_on = {} + +# A dict mapping thread IDs to weakref'ed lists of _ModuleLock instances. +# This maps a thread to the module locks it is blocking on acquiring. The +# values are lists because a single thread could perform a re-entrant import +# and be "in the process" of blocking on locks for more than one module. A +# thread can be "in the process" because a thread cannot actually block on +# acquiring more than one lock but it can have set up bookkeeping that reflects +# that it intends to block on acquiring more than one lock. +# +# The dictionary uses a WeakValueDictionary to avoid keeping unnecessary +# lists around, regardless of GC runs. This way there's no memory leak if +# the list is no longer needed (GH-106176). +_blocking_on = None + + +class _BlockingOnManager: + """A context manager responsible to updating ``_blocking_on``.""" + def __init__(self, thread_id, lock): + self.thread_id = thread_id + self.lock = lock + + def __enter__(self): + """Mark the running thread as waiting for self.lock. via _blocking_on.""" + # Interactions with _blocking_on are *not* protected by the global + # import lock here because each thread only touches the state that it + # owns (state keyed on its thread id). The global import lock is + # re-entrant (i.e., a single thread may take it more than once) so it + # wouldn't help us be correct in the face of re-entrancy either. + + self.blocked_on = _blocking_on.setdefault(self.thread_id, _List()) + self.blocked_on.append(self.lock) + + def __exit__(self, *args, **kwargs): + """Remove self.lock from this thread's _blocking_on list.""" + self.blocked_on.remove(self.lock) class _DeadlockError(RuntimeError): pass + +def _has_deadlocked(target_id, *, seen_ids, candidate_ids, blocking_on): + """Check if 'target_id' is holding the same lock as another thread(s). + + The search within 'blocking_on' starts with the threads listed in + 'candidate_ids'. 'seen_ids' contains any threads that are considered + already traversed in the search. + + Keyword arguments: + target_id -- The thread id to try to reach. + seen_ids -- A set of threads that have already been visited. + candidate_ids -- The thread ids from which to begin. + blocking_on -- A dict representing the thread/blocking-on graph. This may + be the same object as the global '_blocking_on' but it is + a parameter to reduce the impact that global mutable + state has on the result of this function. + """ + if target_id in candidate_ids: + # If we have already reached the target_id, we're done - signal that it + # is reachable. + return True + + # Otherwise, try to reach the target_id from each of the given candidate_ids. + for tid in candidate_ids: + if not (candidate_blocking_on := blocking_on.get(tid)): + # There are no edges out from this node, skip it. + continue + elif tid in seen_ids: + # bpo 38091: the chain of tid's we encounter here eventually leads + # to a fixed point or a cycle, but does not reach target_id. + # This means we would not actually deadlock. This can happen if + # other threads are at the beginning of acquire() below. + return False + seen_ids.add(tid) + + # Follow the edges out from this thread. + edges = [lock.owner for lock in candidate_blocking_on] + if _has_deadlocked(target_id, seen_ids=seen_ids, candidate_ids=edges, + blocking_on=blocking_on): + return True + + return False + + class _ModuleLock: """A recursive lock implementation which is able to detect deadlocks (e.g. thread 1 trying to take locks A then B, and thread 2 trying to @@ -69,33 +230,76 @@ class _ModuleLock: """ def __init__(self, name): - self.lock = _thread.allocate_lock() + # Create an RLock for protecting the import process for the + # corresponding module. Since it is an RLock, a single thread will be + # able to take it more than once. This is necessary to support + # re-entrancy in the import system that arises from (at least) signal + # handlers and the garbage collector. Consider the case of: + # + # import foo + # -> ... + # -> importlib._bootstrap._ModuleLock.acquire + # -> ... + # -> + # -> __del__ + # -> import foo + # -> ... + # -> importlib._bootstrap._ModuleLock.acquire + # -> _BlockingOnManager.__enter__ + # + # If a different thread than the running one holds the lock then the + # thread will have to block on taking the lock, which is what we want + # for thread safety. + self.lock = _thread.RLock() self.wakeup = _thread.allocate_lock() + + # The name of the module for which this is a lock. self.name = name + + # Can end up being set to None if this lock is not owned by any thread + # or the thread identifier for the owning thread. self.owner = None - self.count = 0 - self.waiters = 0 + + # Represent the number of times the owning thread has acquired this lock + # via a list of True. This supports RLock-like ("re-entrant lock") + # behavior, necessary in case a single thread is following a circular + # import dependency and needs to take the lock for a single module + # more than once. + # + # Counts are represented as a list of True because list.append(True) + # and list.pop() are both atomic and thread-safe in CPython and it's hard + # to find another primitive with the same properties. + self.count = [] + + # This is a count of the number of threads that are blocking on + # self.wakeup.acquire() awaiting to get their turn holding this module + # lock. When the module lock is released, if this is greater than + # zero, it is decremented and `self.wakeup` is released one time. The + # intent is that this will let one other thread make more progress on + # acquiring this module lock. This repeats until all the threads have + # gotten a turn. + # + # This is incremented in self.acquire() when a thread notices it is + # going to have to wait for another thread to finish. + # + # See the comment above count for explanation of the representation. + self.waiters = [] def has_deadlock(self): - # Deadlock avoidance for concurrent circular imports. - me = _thread.get_ident() - tid = self.owner - seen = set() - while True: - lock = _blocking_on.get(tid) - if lock is None: - return False - tid = lock.owner - if tid == me: - return True - if tid in seen: - # bpo 38091: the chain of tid's we encounter here - # eventually leads to a fixpoint or a cycle, but - # does not reach 'me'. This means we would not - # actually deadlock. This can happen if other - # threads are at the beginning of acquire() below. - return False - seen.add(tid) + # To avoid deadlocks for concurrent or re-entrant circular imports, + # look at _blocking_on to see if any threads are blocking + # on getting the import lock for any module for which the import lock + # is held by this thread. + return _has_deadlocked( + # Try to find this thread. + target_id=_thread.get_ident(), + seen_ids=set(), + # Start from the thread that holds the import lock for this + # module. + candidate_ids=[self.owner], + # Use the global "blocking on" state. + blocking_on=_blocking_on, + ) def acquire(self): """ @@ -104,39 +308,82 @@ def acquire(self): Otherwise, the lock is always acquired and True is returned. """ tid = _thread.get_ident() - _blocking_on[tid] = self - try: + with _BlockingOnManager(tid, self): while True: + # Protect interaction with state on self with a per-module + # lock. This makes it safe for more than one thread to try to + # acquire the lock for a single module at the same time. with self.lock: - if self.count == 0 or self.owner == tid: + if self.count == [] or self.owner == tid: + # If the lock for this module is unowned then we can + # take the lock immediately and succeed. If the lock + # for this module is owned by the running thread then + # we can also allow the acquire to succeed. This + # supports circular imports (thread T imports module A + # which imports module B which imports module A). self.owner = tid - self.count += 1 + self.count.append(True) return True + + # At this point we know the lock is held (because count != + # 0) by another thread (because owner != tid). We'll have + # to get in line to take the module lock. + + # But first, check to see if this thread would create a + # deadlock by acquiring this module lock. If it would + # then just stop with an error. + # + # It's not clear who is expected to handle this error. + # There is one handler in _lock_unlock_module but many + # times this method is called when entering the context + # manager _ModuleLockManager instead - so _DeadlockError + # will just propagate up to application code. + # + # This seems to be more than just a hypothetical - + # https://stackoverflow.com/questions/59509154 + # https://github.com/encode/django-rest-framework/issues/7078 if self.has_deadlock(): - raise _DeadlockError('deadlock detected by %r' % self) + raise _DeadlockError(f'deadlock detected by {self!r}') + + # Check to see if we're going to be able to acquire the + # lock. If we are going to have to wait then increment + # the waiters so `self.release` will know to unblock us + # later on. We do this part non-blockingly so we don't + # get stuck here before we increment waiters. We have + # this extra acquire call (in addition to the one below, + # outside the self.lock context manager) to make sure + # self.wakeup is held when the next acquire is called (so + # we block). This is probably needlessly complex and we + # should just take self.wakeup in the return codepath + # above. if self.wakeup.acquire(False): - self.waiters += 1 - # Wait for a release() call + self.waiters.append(None) + + # Now take the lock in a blocking fashion. This won't + # complete until the thread holding this lock + # (self.owner) calls self.release. self.wakeup.acquire() + + # Taking the lock has served its purpose (making us wait), so we can + # give it up now. We'll take it w/o blocking again on the + # next iteration around this 'while' loop. self.wakeup.release() - finally: - del _blocking_on[tid] def release(self): tid = _thread.get_ident() with self.lock: if self.owner != tid: raise RuntimeError('cannot release un-acquired lock') - assert self.count > 0 - self.count -= 1 - if self.count == 0: + assert len(self.count) > 0 + self.count.pop() + if not len(self.count): self.owner = None - if self.waiters: - self.waiters -= 1 + if len(self.waiters) > 0: + self.waiters.pop() self.wakeup.release() def __repr__(self): - return '_ModuleLock({!r}) at {}'.format(self.name, id(self)) + return f'_ModuleLock({self.name!r}) at {id(self)}' class _DummyModuleLock: @@ -157,7 +404,7 @@ def release(self): self.count -= 1 def __repr__(self): - return '_DummyModuleLock({!r}) at {}'.format(self.name, id(self)) + return f'_DummyModuleLock({self.name!r}) at {id(self)}' class _ModuleLockManager: @@ -254,7 +501,7 @@ def _requires_builtin(fxn): """Decorator to verify the named module is built-in.""" def _requires_builtin_wrapper(self, fullname): if fullname not in sys.builtin_module_names: - raise ImportError('{!r} is not a built-in module'.format(fullname), + raise ImportError(f'{fullname!r} is not a built-in module', name=fullname) return fxn(self, fullname) _wrap(_requires_builtin_wrapper, fxn) @@ -265,7 +512,7 @@ def _requires_frozen(fxn): """Decorator to verify the named module is frozen.""" def _requires_frozen_wrapper(self, fullname): if not _imp.is_frozen(fullname): - raise ImportError('{!r} is not a frozen module'.format(fullname), + raise ImportError(f'{fullname!r} is not a frozen module', name=fullname) return fxn(self, fullname) _wrap(_requires_frozen_wrapper, fxn) @@ -297,11 +544,6 @@ def _module_repr(module): loader = getattr(module, '__loader__', None) if spec := getattr(module, "__spec__", None): return _module_repr_from_spec(spec) - elif hasattr(loader, 'module_repr'): - try: - return loader.module_repr(module) - except Exception: - pass # Fall through to a catch-all which always succeeds. try: name = module.__name__ @@ -311,11 +553,11 @@ def _module_repr(module): filename = module.__file__ except AttributeError: if loader is None: - return ''.format(name) + return f'' else: - return ''.format(name, loader) + return f'' else: - return ''.format(name, filename) + return f'' class ModuleSpec: @@ -369,14 +611,12 @@ def __init__(self, name, loader, *, origin=None, loader_state=None, self._cached = None def __repr__(self): - args = ['name={!r}'.format(self.name), - 'loader={!r}'.format(self.loader)] + args = [f'name={self.name!r}', f'loader={self.loader!r}'] if self.origin is not None: - args.append('origin={!r}'.format(self.origin)) + args.append(f'origin={self.origin!r}') if self.submodule_search_locations is not None: - args.append('submodule_search_locations={}' - .format(self.submodule_search_locations)) - return '{}({})'.format(self.__class__.__name__, ', '.join(args)) + args.append(f'submodule_search_locations={self.submodule_search_locations}') + return f'{self.__class__.__name__}({", ".join(args)})' def __eq__(self, other): smsl = self.submodule_search_locations @@ -583,18 +823,17 @@ def module_from_spec(spec): def _module_repr_from_spec(spec): """Return the repr to use for the module.""" - # We mostly replicate _module_repr() using the spec attributes. name = '?' if spec.name is None else spec.name if spec.origin is None: if spec.loader is None: - return ''.format(name) + return f'' else: - return ''.format(name, spec.loader) + return f'' else: if spec.has_location: - return ''.format(name, spec.origin) + return f'' else: - return ''.format(spec.name, spec.origin) + return f'' # Used by importlib.reload() and _load_module_shim(). @@ -603,7 +842,7 @@ def _exec(spec, module): name = spec.name with _ModuleLockManager(name): if sys.modules.get(name) is not module: - msg = 'module {!r} not in sys.modules'.format(name) + msg = f'module {name!r} not in sys.modules' raise ImportError(msg, name=name) try: if spec.loader is None: @@ -735,46 +974,18 @@ class BuiltinImporter: _ORIGIN = "built-in" - @staticmethod - def module_repr(module): - """Return repr for the module. - - The method is deprecated. The import machinery does the job itself. - - """ - _warnings.warn("BuiltinImporter.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - return f'' - @classmethod def find_spec(cls, fullname, path=None, target=None): - if path is not None: - return None if _imp.is_builtin(fullname): return spec_from_loader(fullname, cls, origin=cls._ORIGIN) else: return None - @classmethod - def find_module(cls, fullname, path=None): - """Find the built-in module. - - If 'path' is ever specified then the search is considered a failure. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("BuiltinImporter.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = cls.find_spec(fullname, path) - return spec.loader if spec is not None else None - @staticmethod def create_module(spec): """Create a built-in module""" if spec.name not in sys.builtin_module_names: - raise ImportError('{!r} is not a built-in module'.format(spec.name), + raise ImportError(f'{spec.name!r} is not a built-in module', name=spec.name) return _call_with_frames_removed(_imp.create_builtin, spec) @@ -815,17 +1026,6 @@ class FrozenImporter: _ORIGIN = "frozen" - @staticmethod - def module_repr(m): - """Return repr for the module. - - The method is deprecated. The import machinery does the job itself. - - """ - _warnings.warn("FrozenImporter.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - return ''.format(m.__name__, FrozenImporter._ORIGIN) - @classmethod def _fix_up_module(cls, module): spec = module.__spec__ @@ -950,18 +1150,6 @@ def find_spec(cls, fullname, path=None, target=None): spec.submodule_search_locations.insert(0, pkgdir) return spec - @classmethod - def find_module(cls, fullname, path=None): - """Find a frozen module. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("FrozenImporter.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - return cls if _imp.is_frozen(fullname) else None - @staticmethod def create_module(spec): """Set __file__, if able.""" @@ -1041,17 +1229,7 @@ def _resolve_name(name, package, level): if len(bits) < level: raise ImportError('attempted relative import beyond top-level package') base = bits[0] - return '{}.{}'.format(base, name) if name else base - - -def _find_spec_legacy(finder, name, path): - msg = (f"{_object_name(finder)}.find_spec() not found; " - "falling back to find_module()") - _warnings.warn(msg, ImportWarning) - loader = finder.find_module(name, path) - if loader is None: - return None - return spec_from_loader(name, loader) + return f'{base}.{name}' if name else base def _find_spec(name, path, target=None): @@ -1074,9 +1252,7 @@ def _find_spec(name, path, target=None): try: find_spec = finder.find_spec except AttributeError: - spec = _find_spec_legacy(finder, name, path) - if spec is None: - continue + continue else: spec = find_spec(name, path, target) if spec is not None: @@ -1104,7 +1280,7 @@ def _find_spec(name, path, target=None): def _sanity_check(name, package, level): """Verify arguments are "sane".""" if not isinstance(name, str): - raise TypeError('module name must be str, not {}'.format(type(name))) + raise TypeError(f'module name must be str, not {type(name)}') if level < 0: raise ValueError('level must be >= 0') if level > 0: @@ -1134,13 +1310,13 @@ def _find_and_load_unlocked(name, import_): try: path = parent_module.__path__ except AttributeError: - msg = (_ERR_MSG + '; {!r} is not a package').format(name, parent) + msg = f'{_ERR_MSG_PREFIX}{name!r}; {parent!r} is not a package' raise ModuleNotFoundError(msg, name=name) from None parent_spec = parent_module.__spec__ child = name.rpartition('.')[2] spec = _find_spec(name, path) if spec is None: - raise ModuleNotFoundError(_ERR_MSG.format(name), name=name) + raise ModuleNotFoundError(f'{_ERR_MSG_PREFIX}{name!r}', name=name) else: if parent_spec: # Temporarily add child we are currently importing to parent's @@ -1185,8 +1361,7 @@ def _find_and_load(name, import_): _lock_unlock_module(name) if module is None: - message = ('import of {} halted; ' - 'None in sys.modules'.format(name)) + message = f'import of {name} halted; None in sys.modules' raise ModuleNotFoundError(message, name=name) return module @@ -1230,7 +1405,7 @@ def _handle_fromlist(module, fromlist, import_, *, recursive=False): _handle_fromlist(module, module.__all__, import_, recursive=True) elif not hasattr(module, x): - from_name = '{}.{}'.format(module.__name__, x) + from_name = f'{module.__name__}.{x}' try: _call_with_frames_removed(import_, from_name) except ModuleNotFoundError as exc: @@ -1257,7 +1432,7 @@ def _calc___package__(globals): if spec is not None and package != spec.parent: _warnings.warn("__package__ != __spec__.parent " f"({package!r} != {spec.parent!r})", - ImportWarning, stacklevel=3) + DeprecationWarning, stacklevel=3) return package elif spec is not None: return spec.parent @@ -1323,7 +1498,7 @@ def _setup(sys_module, _imp_module): modules, those two modules must be explicitly passed in. """ - global _imp, sys + global _imp, sys, _blocking_on _imp = _imp_module sys = sys_module @@ -1351,6 +1526,9 @@ def _setup(sys_module, _imp_module): builtin_module = sys.modules[builtin_name] setattr(self_module, builtin_name, builtin_module) + # Instantiation requires _weakref to have been set. + _blocking_on = _WeakValueDictionary() + def _install(sys_module, _imp_module): """Install importers for builtin and frozen modules""" diff --git a/Lib/importlib/_bootstrap_external.py b/Lib/importlib/_bootstrap_external.py index f603a89f7f..73ac4405cb 100644 --- a/Lib/importlib/_bootstrap_external.py +++ b/Lib/importlib/_bootstrap_external.py @@ -182,12 +182,22 @@ def _path_isabs(path): return path.startswith(path_separators) +def _path_abspath(path): + """Replacement for os.path.abspath.""" + if not _path_isabs(path): + for sep in path_separators: + path = path.removeprefix(f".{sep}") + return _path_join(_os.getcwd(), path) + else: + return path + + def _write_atomic(path, data, mode=0o666): """Best-effort function to write data to a path atomically. Be prepared to handle a FileExistsError if concurrent writing of the temporary file is attempted.""" # id() is used to generate a pseudo-random filename. - path_tmp = '{}.{}'.format(path, id(path)) + path_tmp = f'{path}.{id(path)}' fd = _os.open(path_tmp, _os.O_EXCL | _os.O_CREAT | _os.O_WRONLY, mode & 0o666) try: @@ -403,11 +413,45 @@ def _write_atomic(path, data, mode=0o666): # Python 3.11a7 3492 (make POP_JUMP_IF_NONE/NOT_NONE/TRUE/FALSE relative) # Python 3.11a7 3493 (Make JUMP_IF_TRUE_OR_POP/JUMP_IF_FALSE_OR_POP relative) # Python 3.11a7 3494 (New location info table) -# Python 3.11b4 3495 (Set line number of module's RESUME instr to 0 per PEP 626) -# Python 3.12 will start with magic number 3500 - +# Python 3.12a1 3500 (Remove PRECALL opcode) +# Python 3.12a1 3501 (YIELD_VALUE oparg == stack_depth) +# Python 3.12a1 3502 (LOAD_FAST_CHECK, no NULL-check in LOAD_FAST) +# Python 3.12a1 3503 (Shrink LOAD_METHOD cache) +# Python 3.12a1 3504 (Merge LOAD_METHOD back into LOAD_ATTR) +# Python 3.12a1 3505 (Specialization/Cache for FOR_ITER) +# Python 3.12a1 3506 (Add BINARY_SLICE and STORE_SLICE instructions) +# Python 3.12a1 3507 (Set lineno of module's RESUME to 0) +# Python 3.12a1 3508 (Add CLEANUP_THROW) +# Python 3.12a1 3509 (Conditional jumps only jump forward) +# Python 3.12a2 3510 (FOR_ITER leaves iterator on the stack) +# Python 3.12a2 3511 (Add STOPITERATION_ERROR instruction) +# Python 3.12a2 3512 (Remove all unused consts from code objects) +# Python 3.12a4 3513 (Add CALL_INTRINSIC_1 instruction, removed STOPITERATION_ERROR, PRINT_EXPR, IMPORT_STAR) +# Python 3.12a4 3514 (Remove ASYNC_GEN_WRAP, LIST_TO_TUPLE, and UNARY_POSITIVE) +# Python 3.12a5 3515 (Embed jump mask in COMPARE_OP oparg) +# Python 3.12a5 3516 (Add COMPARE_AND_BRANCH instruction) +# Python 3.12a5 3517 (Change YIELD_VALUE oparg to exception block depth) +# Python 3.12a6 3518 (Add RETURN_CONST instruction) +# Python 3.12a6 3519 (Modify SEND instruction) +# Python 3.12a6 3520 (Remove PREP_RERAISE_STAR, add CALL_INTRINSIC_2) +# Python 3.12a7 3521 (Shrink the LOAD_GLOBAL caches) +# Python 3.12a7 3522 (Removed JUMP_IF_FALSE_OR_POP/JUMP_IF_TRUE_OR_POP) +# Python 3.12a7 3523 (Convert COMPARE_AND_BRANCH back to COMPARE_OP) +# Python 3.12a7 3524 (Shrink the BINARY_SUBSCR caches) +# Python 3.12b1 3525 (Shrink the CALL caches) +# Python 3.12b1 3526 (Add instrumentation support) +# Python 3.12b1 3527 (Add LOAD_SUPER_ATTR) +# Python 3.12b1 3528 (Add LOAD_SUPER_ATTR_METHOD specialization) +# Python 3.12b1 3529 (Inline list/dict/set comprehensions) +# Python 3.12b1 3530 (Shrink the LOAD_SUPER_ATTR caches) +# Python 3.12b1 3531 (Add PEP 695 changes) + +# Python 3.13 will start with 3550 + +# Please don't copy-paste the same pre-release tag for new entries above!!! +# You should always use the *upcoming* tag. For example, if 3.12a6 came out +# a week ago, I should put "Python 3.12a7" next to my new magic number. -# # MAGIC must change whenever the bytecode emitted by the compiler may no # longer be understood by older implementations of the eval loop (usually # due to the addition of new opcodes). @@ -417,7 +461,7 @@ def _write_atomic(path, data, mode=0o666): # Whenever MAGIC_NUMBER is changed, the ranges in the magic_values array # in PC/launcher.c must also be updated. -MAGIC_NUMBER = (3495).to_bytes(2, 'little') + b'\r\n' +MAGIC_NUMBER = (3531).to_bytes(2, 'little') + b'\r\n' _RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c @@ -474,8 +518,8 @@ def cache_from_source(path, debug_override=None, *, optimization=None): optimization = str(optimization) if optimization != '': if not optimization.isalnum(): - raise ValueError('{!r} is not alphanumeric'.format(optimization)) - almost_filename = '{}.{}{}'.format(almost_filename, _OPT, optimization) + raise ValueError(f'{optimization!r} is not alphanumeric') + almost_filename = f'{almost_filename}.{_OPT}{optimization}' filename = almost_filename + BYTECODE_SUFFIXES[0] if sys.pycache_prefix is not None: # We need an absolute path to the py file to avoid the possibility of @@ -486,8 +530,7 @@ def cache_from_source(path, debug_override=None, *, optimization=None): # make it absolute (`C:\Somewhere\Foo\Bar`), then make it root-relative # (`Somewhere\Foo\Bar`), so we end up placing the bytecode file in an # unambiguous `C:\Bytecode\Somewhere\Foo\Bar\`. - if not _path_isabs(head): - head = _path_join(_os.getcwd(), head) + head = _path_abspath(head) # Strip initial drive from a Windows path. We know we have an absolute # path here, so the second part of the check rules out a POSIX path that @@ -619,26 +662,6 @@ def _wrap(new, old): return _check_name_wrapper -def _find_module_shim(self, fullname): - """Try to find a loader for the specified module by delegating to - self.find_loader(). - - This method is deprecated in favor of finder.find_spec(). - - """ - _warnings.warn("find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - # Call find_loader(). If it returns a string (indicating this - # is a namespace package portion), generate a warning and - # return None. - loader, portions = self.find_loader(fullname) - if loader is None and len(portions): - msg = 'Not importing directory {}: missing __init__' - _warnings.warn(msg.format(portions[0]), ImportWarning) - return loader - - def _classify_pyc(data, name, exc_details): """Perform basic validity checking of a pyc header and return the flags field, which determines how the pyc should be further validated against the source. @@ -733,7 +756,7 @@ def _compile_bytecode(data, name=None, bytecode_path=None, source_path=None): _imp._fix_co_filename(code, source_path) return code else: - raise ImportError('Non-code object in {!r}'.format(bytecode_path), + raise ImportError(f'Non-code object in {bytecode_path!r}', name=name, path=bytecode_path) @@ -800,11 +823,10 @@ def spec_from_file_location(name, location=None, *, loader=None, pass else: location = _os.fspath(location) - if not _path_isabs(location): - try: - location = _path_join(_os.getcwd(), location) - except OSError: - pass + try: + location = _path_abspath(location) + except OSError: + pass # If the location is on the filesystem, but doesn't actually exist, # we could return None here, indicating that the location is not @@ -846,6 +868,54 @@ def spec_from_file_location(name, location=None, *, loader=None, return spec +def _bless_my_loader(module_globals): + """Helper function for _warnings.c + + See GH#97850 for details. + """ + # 2022-10-06(warsaw): For now, this helper is only used in _warnings.c and + # that use case only has the module globals. This function could be + # extended to accept either that or a module object. However, in the + # latter case, it would be better to raise certain exceptions when looking + # at a module, which should have either a __loader__ or __spec__.loader. + # For backward compatibility, it is possible that we'll get an empty + # dictionary for the module globals, and that cannot raise an exception. + if not isinstance(module_globals, dict): + return None + + missing = object() + loader = module_globals.get('__loader__', None) + spec = module_globals.get('__spec__', missing) + + if loader is None: + if spec is missing: + # If working with a module: + # raise AttributeError('Module globals is missing a __spec__') + return None + elif spec is None: + raise ValueError('Module globals is missing a __spec__.loader') + + spec_loader = getattr(spec, 'loader', missing) + + if spec_loader in (missing, None): + if loader is None: + exc = AttributeError if spec_loader is missing else ValueError + raise exc('Module globals is missing a __spec__.loader') + _warnings.warn( + 'Module globals is missing a __spec__.loader', + DeprecationWarning) + spec_loader = loader + + assert spec_loader is not None + if loader is not None and loader != spec_loader: + _warnings.warn( + 'Module globals; __loader__ != __spec__.loader', + DeprecationWarning) + return loader + + return spec_loader + + # Loaders ##################################################################### class WindowsRegistryFinder: @@ -898,22 +968,6 @@ def find_spec(cls, fullname, path=None, target=None): origin=filepath) return spec - @classmethod - def find_module(cls, fullname, path=None): - """Find module named in the registry. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("WindowsRegistryFinder.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = cls.find_spec(fullname, path) - if spec is not None: - return spec.loader - else: - return None - class _LoaderBasics: @@ -935,8 +989,8 @@ def exec_module(self, module): """Execute the module.""" code = self.get_code(module.__name__) if code is None: - raise ImportError('cannot load module {!r} when get_code() ' - 'returns None'.format(module.__name__)) + raise ImportError(f'cannot load module {module.__name__!r} when ' + 'get_code() returns None') _bootstrap._call_with_frames_removed(exec, code, module.__dict__) def load_module(self, fullname): @@ -1077,7 +1131,8 @@ def get_code(self, fullname): source_mtime is not None): if hash_based: if source_hash is None: - source_hash = _imp.source_hash(source_bytes) + source_hash = _imp.source_hash(_RAW_MAGIC_NUMBER, + source_bytes) data = _code_to_hash_pyc(code_object, source_hash, check_source) else: data = _code_to_timestamp_pyc(code_object, source_mtime, @@ -1321,7 +1376,7 @@ def __len__(self): return len(self._recalculate()) def __repr__(self): - return '_NamespacePath({!r})'.format(self._path) + return f'_NamespacePath({self._path!r})' def __contains__(self, item): return item in self._recalculate() @@ -1332,22 +1387,11 @@ def append(self, item): # This class is actually exposed publicly in a namespace package's __loader__ # attribute, so it should be available through a non-private name. -# https://bugs.python.org/issue35673 +# https://github.com/python/cpython/issues/92054 class NamespaceLoader: def __init__(self, name, path, path_finder): self._path = _NamespacePath(name, path, path_finder) - @staticmethod - def module_repr(module): - """Return repr for the module. - - The method is deprecated. The import machinery does the job itself. - - """ - _warnings.warn("NamespaceLoader.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - return ''.format(module.__name__) - def is_package(self, fullname): return True @@ -1440,27 +1484,6 @@ def _path_importer_cache(cls, path): sys.path_importer_cache[path] = finder return finder - @classmethod - def _legacy_get_spec(cls, fullname, finder): - # This would be a good place for a DeprecationWarning if - # we ended up going that route. - if hasattr(finder, 'find_loader'): - msg = (f"{_bootstrap._object_name(finder)}.find_spec() not found; " - "falling back to find_loader()") - _warnings.warn(msg, ImportWarning) - loader, portions = finder.find_loader(fullname) - else: - msg = (f"{_bootstrap._object_name(finder)}.find_spec() not found; " - "falling back to find_module()") - _warnings.warn(msg, ImportWarning) - loader = finder.find_module(fullname) - portions = [] - if loader is not None: - return _bootstrap.spec_from_loader(fullname, loader) - spec = _bootstrap.ModuleSpec(fullname, None) - spec.submodule_search_locations = portions - return spec - @classmethod def _get_spec(cls, fullname, path, target=None): """Find the loader or namespace_path for this module/package name.""" @@ -1472,10 +1495,7 @@ def _get_spec(cls, fullname, path, target=None): continue finder = cls._path_importer_cache(entry) if finder is not None: - if hasattr(finder, 'find_spec'): - spec = finder.find_spec(fullname, target) - else: - spec = cls._legacy_get_spec(fullname, finder) + spec = finder.find_spec(fullname, target) if spec is None: continue if spec.loader is not None: @@ -1517,22 +1537,6 @@ def find_spec(cls, fullname, path=None, target=None): else: return spec - @classmethod - def find_module(cls, fullname, path=None): - """find the module on sys.path or 'path' based on sys.path_hooks and - sys.path_importer_cache. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("PathFinder.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = cls.find_spec(fullname, path) - if spec is None: - return None - return spec.loader - @staticmethod def find_distributions(*args, **kwargs): """ @@ -1567,10 +1571,8 @@ def __init__(self, path, *loader_details): # Base (directory) path if not path or path == '.': self.path = _os.getcwd() - elif not _path_isabs(path): - self.path = _path_join(_os.getcwd(), path) else: - self.path = path + self.path = _path_abspath(path) self._path_mtime = -1 self._path_cache = set() self._relaxed_path_cache = set() @@ -1579,23 +1581,6 @@ def invalidate_caches(self): """Invalidate the directory mtime.""" self._path_mtime = -1 - find_module = _find_module_shim - - def find_loader(self, fullname): - """Try to find a loader for the specified module, or the namespace - package portions. Returns (loader, list-of-portions). - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("FileFinder.find_loader() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = self.find_spec(fullname) - if spec is None: - return None, [] - return spec.loader, spec.submodule_search_locations or [] - def _get_spec(self, loader_class, fullname, path, smsl, target): loader = loader_class(fullname, path) return spec_from_file_location(fullname, path, loader=loader, @@ -1675,7 +1660,7 @@ def _fill_cache(self): for item in contents: name, dot, suffix = item.partition('.') if dot: - new_name = '{}.{}'.format(name, suffix.lower()) + new_name = f'{name}.{suffix.lower()}' else: new_name = name lower_suffix_contents.add(new_name) @@ -1702,7 +1687,7 @@ def path_hook_for_FileFinder(path): return path_hook_for_FileFinder def __repr__(self): - return 'FileFinder({!r})'.format(self.path) + return f'FileFinder({self.path!r})' # Import setup ############################################################### @@ -1720,6 +1705,8 @@ def _fix_up_module(ns, name, pathname, cpathname=None): loader = SourceFileLoader(name, pathname) if not spec: spec = spec_from_file_location(name, pathname, loader=loader) + if cpathname: + spec.cached = _path_abspath(cpathname) try: ns['__spec__'] = spec ns['__loader__'] = loader diff --git a/Lib/importlib/abc.py b/Lib/importlib/abc.py index 3fa151f390..b56fa94eb9 100644 --- a/Lib/importlib/abc.py +++ b/Lib/importlib/abc.py @@ -15,20 +15,29 @@ import abc import warnings -# for compatibility with Python 3.10 -from .resources.abc import ResourceReader, Traversable, TraversableResources +from .resources import abc as _resources_abc __all__ = [ - 'Loader', 'Finder', 'MetaPathFinder', 'PathEntryFinder', + 'Loader', 'MetaPathFinder', 'PathEntryFinder', 'ResourceLoader', 'InspectLoader', 'ExecutionLoader', 'FileLoader', 'SourceLoader', - - # for compatibility with Python 3.10 - 'ResourceReader', 'Traversable', 'TraversableResources', ] +def __getattr__(name): + """ + For backwards compatibility, continue to make names + from _resources_abc available through this module. #93963 + """ + if name in _resources_abc.__all__: + obj = getattr(_resources_abc, name) + warnings._deprecated(f"{__name__}.{name}", remove=(3, 14)) + globals()[name] = obj + return obj + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') + + def _register(abstract_cls, *classes): for cls in classes: abstract_cls.register(cls) @@ -40,38 +49,6 @@ def _register(abstract_cls, *classes): abstract_cls.register(frozen_cls) -class Finder(metaclass=abc.ABCMeta): - - """Legacy abstract base class for import finders. - - It may be subclassed for compatibility with legacy third party - reimplementations of the import system. Otherwise, finder - implementations should derive from the more specific MetaPathFinder - or PathEntryFinder ABCs. - - Deprecated since Python 3.3 - """ - - def __init__(self): - warnings.warn("the Finder ABC is deprecated and " - "slated for removal in Python 3.12; use MetaPathFinder " - "or PathEntryFinder instead", - DeprecationWarning) - - @abc.abstractmethod - def find_module(self, fullname, path=None): - """An abstract method that should find a module. - The fullname is a str and the optional path is a str or None. - Returns a Loader object or None. - """ - warnings.warn("importlib.abc.Finder along with its find_module() " - "method are deprecated and " - "slated for removal in Python 3.12; use " - "MetaPathFinder.find_spec() or " - "PathEntryFinder.find_spec() instead", - DeprecationWarning) - - class MetaPathFinder(metaclass=abc.ABCMeta): """Abstract base class for import finders on sys.meta_path.""" @@ -79,27 +56,6 @@ class MetaPathFinder(metaclass=abc.ABCMeta): # We don't define find_spec() here since that would break # hasattr checks we do to support backward compatibility. - def find_module(self, fullname, path): - """Return a loader for the module. - - If no module is found, return None. The fullname is a str and - the path is a list of strings or None. - - This method is deprecated since Python 3.4 in favor of - finder.find_spec(). If find_spec() exists then backwards-compatible - functionality is provided for this method. - - """ - warnings.warn("MetaPathFinder.find_module() is deprecated since Python " - "3.4 in favor of MetaPathFinder.find_spec() and is " - "slated for removal in Python 3.12", - DeprecationWarning, - stacklevel=2) - if not hasattr(self, 'find_spec'): - return None - found = self.find_spec(fullname, path) - return found.loader if found is not None else None - def invalidate_caches(self): """An optional method for clearing the finder's cache, if any. This method is used by importlib.invalidate_caches(). @@ -113,43 +69,6 @@ class PathEntryFinder(metaclass=abc.ABCMeta): """Abstract base class for path entry finders used by PathFinder.""" - # We don't define find_spec() here since that would break - # hasattr checks we do to support backward compatibility. - - def find_loader(self, fullname): - """Return (loader, namespace portion) for the path entry. - - The fullname is a str. The namespace portion is a sequence of - path entries contributing to part of a namespace package. The - sequence may be empty. If loader is not None, the portion will - be ignored. - - The portion will be discarded if another path entry finder - locates the module as a normal module or package. - - This method is deprecated since Python 3.4 in favor of - finder.find_spec(). If find_spec() is provided than backwards-compatible - functionality is provided. - """ - warnings.warn("PathEntryFinder.find_loader() is deprecated since Python " - "3.4 in favor of PathEntryFinder.find_spec() " - "(available since 3.4)", - DeprecationWarning, - stacklevel=2) - if not hasattr(self, 'find_spec'): - return None, [] - found = self.find_spec(fullname) - if found is not None: - if not found.submodule_search_locations: - portions = [] - else: - portions = found.submodule_search_locations - return found.loader, portions - else: - return None, [] - - find_module = _bootstrap_external._find_module_shim - def invalidate_caches(self): """An optional method for clearing the finder's cache, if any. This method is used by PathFinder.invalidate_caches(). diff --git a/Lib/importlib/metadata/__init__.py b/Lib/importlib/metadata/__init__.py index 68828269fc..56ee403832 100644 --- a/Lib/importlib/metadata/__init__.py +++ b/Lib/importlib/metadata/__init__.py @@ -12,7 +12,9 @@ import functools import itertools import posixpath +import contextlib import collections +import inspect from . import _adapters, _meta from ._collections import FreezableDefaultDict, Pair @@ -24,7 +26,7 @@ from importlib import import_module from importlib.abc import MetaPathFinder from itertools import starmap -from typing import List, Mapping, Optional, Union +from typing import List, Mapping, Optional, cast __all__ = [ @@ -140,6 +142,7 @@ class DeprecatedTuple: 1 """ + # Do not remove prior to 2023-05-01 or Python 3.13 _warn = functools.partial( warnings.warn, "EntryPoint tuple interface is deprecated. Access members by name.", @@ -228,17 +231,6 @@ def _for(self, dist): vars(self).update(dist=dist) return self - def __iter__(self): - """ - Supply iter so one may construct dicts of EntryPoints by name. - """ - msg = ( - "Construction of dict of EntryPoints is deprecated in " - "favor of EntryPoints." - ) - warnings.warn(msg, DeprecationWarning) - return iter((self.name, self)) - def matches(self, **params): """ EntryPoint matches the given parameters. @@ -284,77 +276,7 @@ def __hash__(self): return hash(self._key()) -class DeprecatedList(list): - """ - Allow an otherwise immutable object to implement mutability - for compatibility. - - >>> recwarn = getfixture('recwarn') - >>> dl = DeprecatedList(range(3)) - >>> dl[0] = 1 - >>> dl.append(3) - >>> del dl[3] - >>> dl.reverse() - >>> dl.sort() - >>> dl.extend([4]) - >>> dl.pop(-1) - 4 - >>> dl.remove(1) - >>> dl += [5] - >>> dl + [6] - [1, 2, 5, 6] - >>> dl + (6,) - [1, 2, 5, 6] - >>> dl.insert(0, 0) - >>> dl - [0, 1, 2, 5] - >>> dl == [0, 1, 2, 5] - True - >>> dl == (0, 1, 2, 5) - True - >>> len(recwarn) - 1 - """ - - __slots__ = () - - _warn = functools.partial( - warnings.warn, - "EntryPoints list interface is deprecated. Cast to list if needed.", - DeprecationWarning, - stacklevel=2, - ) - - def _wrap_deprecated_method(method_name: str): # type: ignore - def wrapped(self, *args, **kwargs): - self._warn() - return getattr(super(), method_name)(*args, **kwargs) - - return method_name, wrapped - - locals().update( - map( - _wrap_deprecated_method, - '__setitem__ __delitem__ append reverse extend pop remove ' - '__iadd__ insert sort'.split(), - ) - ) - - def __add__(self, other): - if not isinstance(other, tuple): - self._warn() - other = tuple(other) - return self.__class__(tuple(self) + other) - - def __eq__(self, other): - if not isinstance(other, tuple): - self._warn() - other = tuple(other) - - return tuple(self).__eq__(other) - - -class EntryPoints(DeprecatedList): +class EntryPoints(tuple): """ An immutable collection of selectable EntryPoint objects. """ @@ -365,14 +287,6 @@ def __getitem__(self, name): # -> EntryPoint: """ Get the EntryPoint in self matching name. """ - if isinstance(name, int): - warnings.warn( - "Accessing entry points by index is deprecated. " - "Cast to tuple if needed.", - DeprecationWarning, - stacklevel=2, - ) - return super().__getitem__(name) try: return next(iter(self.select(name=name))) except StopIteration: @@ -396,10 +310,6 @@ def names(self): def groups(self): """ Return the set of all groups of all entry points. - - For coverage while SelectableGroups is present. - >>> EntryPoints().groups - set() """ return {ep.group for ep in self} @@ -415,101 +325,6 @@ def _from_text(text): ) -class Deprecated: - """ - Compatibility add-in for mapping to indicate that - mapping behavior is deprecated. - - >>> recwarn = getfixture('recwarn') - >>> class DeprecatedDict(Deprecated, dict): pass - >>> dd = DeprecatedDict(foo='bar') - >>> dd.get('baz', None) - >>> dd['foo'] - 'bar' - >>> list(dd) - ['foo'] - >>> list(dd.keys()) - ['foo'] - >>> 'foo' in dd - True - >>> list(dd.values()) - ['bar'] - >>> len(recwarn) - 1 - """ - - _warn = functools.partial( - warnings.warn, - "SelectableGroups dict interface is deprecated. Use select.", - DeprecationWarning, - stacklevel=2, - ) - - def __getitem__(self, name): - self._warn() - return super().__getitem__(name) - - def get(self, name, default=None): - self._warn() - return super().get(name, default) - - def __iter__(self): - self._warn() - return super().__iter__() - - def __contains__(self, *args): - self._warn() - return super().__contains__(*args) - - def keys(self): - self._warn() - return super().keys() - - def values(self): - self._warn() - return super().values() - - -class SelectableGroups(Deprecated, dict): - """ - A backward- and forward-compatible result from - entry_points that fully implements the dict interface. - """ - - @classmethod - def load(cls, eps): - by_group = operator.attrgetter('group') - ordered = sorted(eps, key=by_group) - grouped = itertools.groupby(ordered, by_group) - return cls((group, EntryPoints(eps)) for group, eps in grouped) - - @property - def _all(self): - """ - Reconstruct a list of all entrypoints from the groups. - """ - groups = super(Deprecated, self).values() - return EntryPoints(itertools.chain.from_iterable(groups)) - - @property - def groups(self): - return self._all.groups - - @property - def names(self): - """ - for coverage: - >>> SelectableGroups().names - set() - """ - return self._all.names - - def select(self, **params): - if not params: - return self - return self._all.select(**params) - - class PackagePath(pathlib.PurePosixPath): """A reference to a path in a package""" @@ -534,11 +349,30 @@ def __repr__(self): return f'' -class Distribution: +class DeprecatedNonAbstract: + def __new__(cls, *args, **kwargs): + all_names = { + name for subclass in inspect.getmro(cls) for name in vars(subclass) + } + abstract = { + name + for name in all_names + if getattr(getattr(cls, name), '__isabstractmethod__', False) + } + if abstract: + warnings.warn( + f"Unimplemented abstract methods {abstract}", + DeprecationWarning, + stacklevel=2, + ) + return super().__new__(cls) + + +class Distribution(DeprecatedNonAbstract): """A Python distribution package.""" @abc.abstractmethod - def read_text(self, filename): + def read_text(self, filename) -> Optional[str]: """Attempt to load metadata file given by the name. :param filename: The name of the file in the distribution info. @@ -612,7 +446,7 @@ def metadata(self) -> _meta.PackageMetadata: The returned object will have keys that name the various bits of metadata. See PEP 566 for details. """ - text = ( + opt_text = ( self.read_text('METADATA') or self.read_text('PKG-INFO') # This last clause is here to support old egg-info files. Its @@ -620,6 +454,7 @@ def metadata(self) -> _meta.PackageMetadata: # (which points to the egg-info file) attribute unchanged. or self.read_text('') ) + text = cast(str, opt_text) return _adapters.Message(email.message_from_string(text)) @property @@ -648,8 +483,8 @@ def files(self): :return: List of PackagePath for this distribution or None Result is `None` if the metadata file that enumerates files - (i.e. RECORD for dist-info or SOURCES.txt for egg-info) is - missing. + (i.e. RECORD for dist-info, or installed-files.txt or + SOURCES.txt for egg-info) is missing. Result may be empty if the metadata exists but is empty. """ @@ -662,9 +497,19 @@ def make_file(name, hash=None, size_str=None): @pass_none def make_files(lines): - return list(starmap(make_file, csv.reader(lines))) + return starmap(make_file, csv.reader(lines)) - return make_files(self._read_files_distinfo() or self._read_files_egginfo()) + @pass_none + def skip_missing_files(package_paths): + return list(filter(lambda path: path.locate().exists(), package_paths)) + + return skip_missing_files( + make_files( + self._read_files_distinfo() + or self._read_files_egginfo_installed() + or self._read_files_egginfo_sources() + ) + ) def _read_files_distinfo(self): """ @@ -673,10 +518,45 @@ def _read_files_distinfo(self): text = self.read_text('RECORD') return text and text.splitlines() - def _read_files_egginfo(self): + def _read_files_egginfo_installed(self): + """ + Read installed-files.txt and return lines in a similar + CSV-parsable format as RECORD: each file must be placed + relative to the site-packages directory and must also be + quoted (since file names can contain literal commas). + + This file is written when the package is installed by pip, + but it might not be written for other installation methods. + Assume the file is accurate if it exists. """ - SOURCES.txt might contain literal commas, so wrap each line - in quotes. + text = self.read_text('installed-files.txt') + # Prepend the .egg-info/ subdir to the lines in this file. + # But this subdir is only available from PathDistribution's + # self._path. + subdir = getattr(self, '_path', None) + if not text or not subdir: + return + + paths = ( + (subdir / name) + .resolve() + .relative_to(self.locate_file('').resolve()) + .as_posix() + for name in text.splitlines() + ) + return map('"{}"'.format, paths) + + def _read_files_egginfo_sources(self): + """ + Read SOURCES.txt and return lines in a similar CSV-parsable + format as RECORD: each file name must be quoted (since it + might contain literal commas). + + Note that SOURCES.txt is not a reliable source for what + files are installed by a package. This file is generated + for a source archive, and the files that are present + there (e.g. setup.py) may not correctly reflect the files + that are present after the package has been installed. """ text = self.read_text('SOURCES.txt') return text and map('"{}"'.format, text.splitlines()) @@ -1023,27 +903,19 @@ def version(distribution_name): """ -def entry_points(**params) -> Union[EntryPoints, SelectableGroups]: +def entry_points(**params) -> EntryPoints: """Return EntryPoint objects for all installed packages. Pass selection parameters (group or name) to filter the result to entry points matching those properties (see EntryPoints.select()). - For compatibility, returns ``SelectableGroups`` object unless - selection parameters are supplied. In the future, this function - will return ``EntryPoints`` instead of ``SelectableGroups`` - even when no selection parameters are supplied. - - For maximum future compatibility, pass selection parameters - or invoke ``.select`` with parameters on the result. - - :return: EntryPoints or SelectableGroups for all installed packages. + :return: EntryPoints for all installed packages. """ eps = itertools.chain.from_iterable( dist.entry_points for dist in _unique(distributions()) ) - return SelectableGroups.load(eps).select(**params) + return EntryPoints(eps).select(**params) def files(distribution_name): @@ -1087,8 +959,13 @@ def _top_level_declared(dist): def _top_level_inferred(dist): - return { - f.parts[0] if len(f.parts) > 1 else f.with_suffix('').name + opt_names = { + f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in always_iterable(dist.files) - if f.suffix == ".py" } + + @pass_none + def importable_name(name): + return '.' not in name + + return filter(importable_name, opt_names) diff --git a/Lib/importlib/metadata/_adapters.py b/Lib/importlib/metadata/_adapters.py index aa460d3eda..6aed69a308 100644 --- a/Lib/importlib/metadata/_adapters.py +++ b/Lib/importlib/metadata/_adapters.py @@ -1,3 +1,5 @@ +import functools +import warnings import re import textwrap import email.message @@ -5,6 +7,15 @@ from ._text import FoldedCase +# Do not remove prior to 2024-01-01 or Python 3.14 +_warn = functools.partial( + warnings.warn, + "Implicit None on return values is deprecated and will raise KeyErrors.", + DeprecationWarning, + stacklevel=2, +) + + class Message(email.message.Message): multiple_use_keys = set( map( @@ -39,6 +50,16 @@ def __init__(self, *args, **kwargs): def __iter__(self): return super().__iter__() + def __getitem__(self, item): + """ + Warn users that a ``KeyError`` can be expected when a + mising key is supplied. Ref python/importlib_metadata#371. + """ + res = super().__getitem__(item) + if res is None: + _warn() + return res + def _repair_headers(self): def redent(value): "Correct for RFC822 indentation" diff --git a/Lib/importlib/metadata/_meta.py b/Lib/importlib/metadata/_meta.py index d5c0576194..c9a7ef906a 100644 --- a/Lib/importlib/metadata/_meta.py +++ b/Lib/importlib/metadata/_meta.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Iterator, List, Protocol, TypeVar, Union +from typing import Protocol +from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union, overload _T = TypeVar("_T") @@ -17,7 +18,21 @@ def __getitem__(self, key: str) -> str: def __iter__(self) -> Iterator[str]: ... # pragma: no cover - def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]: + @overload + def get(self, name: str, failobj: None = None) -> Optional[str]: + ... # pragma: no cover + + @overload + def get(self, name: str, failobj: _T) -> Union[str, _T]: + ... # pragma: no cover + + # overload per python/importlib_metadata#435 + @overload + def get_all(self, name: str, failobj: None = None) -> Optional[List[Any]]: + ... # pragma: no cover + + @overload + def get_all(self, name: str, failobj: _T) -> Union[List[Any], _T]: """ Return all values associated with a possibly multi-valued key. """ @@ -29,18 +44,19 @@ def json(self) -> Dict[str, Union[str, List[str]]]: """ -class SimplePath(Protocol): +class SimplePath(Protocol[_T]): """ A minimal subset of pathlib.Path required by PathDistribution. """ - def joinpath(self) -> 'SimplePath': + def joinpath(self) -> _T: ... # pragma: no cover - def __truediv__(self) -> 'SimplePath': + def __truediv__(self, other: Union[str, _T]) -> _T: ... # pragma: no cover - def parent(self) -> 'SimplePath': + @property + def parent(self) -> _T: ... # pragma: no cover def read_text(self) -> str: diff --git a/Lib/importlib/resources/_adapters.py b/Lib/importlib/resources/_adapters.py index ea363d86a5..50688fbb66 100644 --- a/Lib/importlib/resources/_adapters.py +++ b/Lib/importlib/resources/_adapters.py @@ -34,9 +34,7 @@ def _io_wrapper(file, mode='r', *args, **kwargs): return TextIOWrapper(file, *args, **kwargs) elif mode == 'rb': return file - raise ValueError( - "Invalid mode value '{}', only 'r' and 'rb' are supported".format(mode) - ) + raise ValueError(f"Invalid mode value '{mode}', only 'r' and 'rb' are supported") class CompatibilityFiles: diff --git a/Lib/importlib/resources/_common.py b/Lib/importlib/resources/_common.py index ca1fa8ab2f..a390253534 100644 --- a/Lib/importlib/resources/_common.py +++ b/Lib/importlib/resources/_common.py @@ -5,25 +5,58 @@ import contextlib import types import importlib +import inspect +import warnings +import itertools -from typing import Union, Optional +from typing import Union, Optional, cast from .abc import ResourceReader, Traversable from ._adapters import wrap_spec Package = Union[types.ModuleType, str] +Anchor = Package -def files(package): - # type: (Package) -> Traversable +def package_to_anchor(func): """ - Get a Traversable resource from a package + Replace 'package' parameter as 'anchor' and warn about the change. + + Other errors should fall through. + + >>> files('a', 'b') + Traceback (most recent call last): + TypeError: files() takes from 0 to 1 positional arguments but 2 were given + """ + undefined = object() + + @functools.wraps(func) + def wrapper(anchor=undefined, package=undefined): + if package is not undefined: + if anchor is not undefined: + return func(anchor, package) + warnings.warn( + "First parameter to files is renamed to 'anchor'", + DeprecationWarning, + stacklevel=2, + ) + return func(package) + elif anchor is undefined: + return func() + return func(anchor) + + return wrapper + + +@package_to_anchor +def files(anchor: Optional[Anchor] = None) -> Traversable: + """ + Get a Traversable resource for an anchor. """ - return from_package(get_package(package)) + return from_package(resolve(anchor)) -def get_resource_reader(package): - # type: (types.ModuleType) -> Optional[ResourceReader] +def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]: """ Return the package's loader if it's a ResourceReader. """ @@ -39,24 +72,39 @@ def get_resource_reader(package): return reader(spec.name) # type: ignore -def resolve(cand): - # type: (Package) -> types.ModuleType - return cand if isinstance(cand, types.ModuleType) else importlib.import_module(cand) +@functools.singledispatch +def resolve(cand: Optional[Anchor]) -> types.ModuleType: + return cast(types.ModuleType, cand) + + +@resolve.register +def _(cand: str) -> types.ModuleType: + return importlib.import_module(cand) + +@resolve.register +def _(cand: None) -> types.ModuleType: + return resolve(_infer_caller().f_globals['__name__']) -def get_package(package): - # type: (Package) -> types.ModuleType - """Take a package name or module object and return the module. - Raise an exception if the resolved module is not a package. +def _infer_caller(): """ - resolved = resolve(package) - if wrap_spec(resolved).submodule_search_locations is None: - raise TypeError(f'{package!r} is not a package') - return resolved + Walk the stack and find the frame of the first caller not in this module. + """ + + def is_this_file(frame_info): + return frame_info.filename == __file__ + + def is_wrapper(frame_info): + return frame_info.function == 'wrapper' + + not_this_file = itertools.filterfalse(is_this_file, inspect.stack()) + # also exclude 'wrapper' due to singledispatch in the call stack + callers = itertools.filterfalse(is_wrapper, not_this_file) + return next(callers).frame -def from_package(package): +def from_package(package: types.ModuleType): """ Return a Traversable object for the given package. @@ -67,10 +115,14 @@ def from_package(package): @contextlib.contextmanager -def _tempfile(reader, suffix='', - # gh-93353: Keep a reference to call os.remove() in late Python - # finalization. - *, _os_remove=os.remove): +def _tempfile( + reader, + suffix='', + # gh-93353: Keep a reference to call os.remove() in late Python + # finalization. + *, + _os_remove=os.remove, +): # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try' # blocks due to the need to close the temporary file to work on Windows # properly. @@ -89,13 +141,30 @@ def _tempfile(reader, suffix='', pass +def _temp_file(path): + return _tempfile(path.read_bytes, suffix=path.name) + + +def _is_present_dir(path: Traversable) -> bool: + """ + Some Traversables implement ``is_dir()`` to raise an + exception (i.e. ``FileNotFoundError``) when the + directory doesn't exist. This function wraps that call + to always return a boolean and only return True + if there's a dir and it exists. + """ + with contextlib.suppress(FileNotFoundError): + return path.is_dir() + return False + + @functools.singledispatch def as_file(path): """ Given a Traversable object, return that object as a path on the local file system in a context manager. """ - return _tempfile(path.read_bytes, suffix=path.name) + return _temp_dir(path) if _is_present_dir(path) else _temp_file(path) @as_file.register(pathlib.Path) @@ -105,3 +174,34 @@ def _(path): Degenerate behavior for pathlib.Path objects. """ yield path + + +@contextlib.contextmanager +def _temp_path(dir: tempfile.TemporaryDirectory): + """ + Wrap tempfile.TemporyDirectory to return a pathlib object. + """ + with dir as result: + yield pathlib.Path(result) + + +@contextlib.contextmanager +def _temp_dir(path): + """ + Given a traversable dir, recursively replicate the whole tree + to the file system in a context manager. + """ + assert path.is_dir() + with _temp_path(tempfile.TemporaryDirectory()) as temp_dir: + yield _write_contents(temp_dir, path) + + +def _write_contents(target, source): + child = target.joinpath(source.name) + if source.is_dir(): + child.mkdir() + for item in source.iterdir(): + _write_contents(child, item) + else: + child.write_bytes(source.read_bytes()) + return child diff --git a/Lib/importlib/resources/_itertools.py b/Lib/importlib/resources/_itertools.py index cce05582ff..7b775ef5ae 100644 --- a/Lib/importlib/resources/_itertools.py +++ b/Lib/importlib/resources/_itertools.py @@ -1,35 +1,38 @@ -from itertools import filterfalse +# from more_itertools 9.0 +def only(iterable, default=None, too_long=None): + """If *iterable* has only one item, return it. + If it has zero items, return *default*. + If it has more than one item, raise the exception given by *too_long*, + which is ``ValueError`` by default. + >>> only([], default='missing') + 'missing' + >>> only([1]) + 1 + >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 1, 2, + and perhaps more.' + >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError + Note that :func:`only` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check + iterable contents less destructively. + """ + it = iter(iterable) + first_value = next(it, default) -from typing import ( - Callable, - Iterable, - Iterator, - Optional, - Set, - TypeVar, - Union, -) - -# Type and type variable definitions -_T = TypeVar('_T') -_U = TypeVar('_U') - - -def unique_everseen( - iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = None -) -> Iterator[_T]: - "List unique elements, preserving order. Remember all elements ever seen." - # unique_everseen('AAAABBBCCDAABBB') --> A B C D - # unique_everseen('ABBCcAD', str.lower) --> A B C D - seen: Set[Union[_T, _U]] = set() - seen_add = seen.add - if key is None: - for element in filterfalse(seen.__contains__, iterable): - seen_add(element) - yield element + try: + second_value = next(it) + except StopIteration: + pass else: - for element in iterable: - k = key(element) - if k not in seen: - seen_add(k) - yield element + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value diff --git a/Lib/importlib/resources/_legacy.py b/Lib/importlib/resources/_legacy.py index 1d5d3f1fbb..b1ea8105da 100644 --- a/Lib/importlib/resources/_legacy.py +++ b/Lib/importlib/resources/_legacy.py @@ -27,8 +27,7 @@ def wrapper(*args, **kwargs): return wrapper -def normalize_path(path): - # type: (Any) -> str +def normalize_path(path: Any) -> str: """Normalize a path by ensuring it is a string. If the resulting string contains path separators, an exception is raised. diff --git a/Lib/importlib/resources/abc.py b/Lib/importlib/resources/abc.py index 0b7bfdc415..6750a7aaf1 100644 --- a/Lib/importlib/resources/abc.py +++ b/Lib/importlib/resources/abc.py @@ -1,6 +1,8 @@ import abc import io +import itertools import os +import pathlib from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional from typing import runtime_checkable, Protocol from typing import Union @@ -53,6 +55,10 @@ def contents(self) -> Iterable[str]: raise FileNotFoundError +class TraversalError(Exception): + pass + + @runtime_checkable class Traversable(Protocol): """ @@ -95,7 +101,6 @@ def is_file(self) -> bool: Return True if self is a file """ - @abc.abstractmethod def joinpath(self, *descendants: StrPath) -> "Traversable": """ Return Traversable resolved with any descendants applied. @@ -104,6 +109,22 @@ def joinpath(self, *descendants: StrPath) -> "Traversable": and each may contain multiple levels separated by ``posixpath.sep`` (``/``). """ + if not descendants: + return self + names = itertools.chain.from_iterable( + path.parts for path in map(pathlib.PurePosixPath, descendants) + ) + target = next(names) + matches = ( + traversable for traversable in self.iterdir() if traversable.name == target + ) + try: + match = next(matches) + except StopIteration: + raise TraversalError( + "Target not found during traversal.", target, list(names) + ) + return match.joinpath(*names) def __truediv__(self, child: StrPath) -> "Traversable": """ @@ -121,7 +142,8 @@ def open(self, mode='r', *args, **kwargs): accepted by io.TextIOWrapper. """ - @abc.abstractproperty + @property + @abc.abstractmethod def name(self) -> str: """ The base name of this object without any parent references. diff --git a/Lib/importlib/resources/readers.py b/Lib/importlib/resources/readers.py index b470a2062b..c3cdf769cb 100644 --- a/Lib/importlib/resources/readers.py +++ b/Lib/importlib/resources/readers.py @@ -1,11 +1,12 @@ import collections -import operator +import itertools import pathlib +import operator import zipfile from . import abc -from ._itertools import unique_everseen +from ._itertools import only def remove_duplicates(items): @@ -41,8 +42,10 @@ def open_resource(self, resource): raise FileNotFoundError(exc.args[0]) def is_resource(self, path): - # workaround for `zipfile.Path.is_file` returning true - # for non-existent paths. + """ + Workaround for `zipfile.Path.is_file` returning true + for non-existent paths. + """ target = self.files().joinpath(path) return target.is_file() and target.exists() @@ -67,8 +70,10 @@ def __init__(self, *paths): raise NotADirectoryError('MultiplexedPath only supports directories') def iterdir(self): - files = (file for path in self._paths for file in path.iterdir()) - return unique_everseen(files, key=operator.attrgetter('name')) + children = (child for path in self._paths for child in path.iterdir()) + by_name = operator.attrgetter('name') + groups = itertools.groupby(sorted(children, key=by_name), key=by_name) + return map(self._follow, (locs for name, locs in groups)) def read_bytes(self): raise FileNotFoundError(f'{self} is not a file') @@ -82,15 +87,32 @@ def is_dir(self): def is_file(self): return False - def joinpath(self, child): - # first try to find child in current paths - for file in self.iterdir(): - if file.name == child: - return file - # if it does not exist, construct it with the first path - return self._paths[0] / child + def joinpath(self, *descendants): + try: + return super().joinpath(*descendants) + except abc.TraversalError: + # One of the paths did not resolve (a directory does not exist). + # Just return something that will not exist. + return self._paths[0].joinpath(*descendants) + + @classmethod + def _follow(cls, children): + """ + Construct a MultiplexedPath if needed. + + If children contains a sole element, return it. + Otherwise, return a MultiplexedPath of the items. + Unless one of the items is not a Directory, then return the first. + """ + subdirs, one_dir, one_file = itertools.tee(children, 3) - __truediv__ = joinpath + try: + return only(one_dir) + except ValueError: + try: + return cls(*subdirs) + except NotADirectoryError: + return next(one_file) def open(self, *args, **kwargs): raise FileNotFoundError(f'{self} is not a file') diff --git a/Lib/importlib/resources/simple.py b/Lib/importlib/resources/simple.py index d0fbf23776..7770c922c8 100644 --- a/Lib/importlib/resources/simple.py +++ b/Lib/importlib/resources/simple.py @@ -16,31 +16,28 @@ class SimpleReader(abc.ABC): provider. """ - @abc.abstractproperty - def package(self): - # type: () -> str + @property + @abc.abstractmethod + def package(self) -> str: """ The name of the package for which this reader loads resources. """ @abc.abstractmethod - def children(self): - # type: () -> List['SimpleReader'] + def children(self) -> List['SimpleReader']: """ Obtain an iterable of SimpleReader for available child containers (e.g. directories). """ @abc.abstractmethod - def resources(self): - # type: () -> List[str] + def resources(self) -> List[str]: """ Obtain available named resources for this virtual package. """ @abc.abstractmethod - def open_binary(self, resource): - # type: (str) -> BinaryIO + def open_binary(self, resource: str) -> BinaryIO: """ Obtain a File-like for a named resource. """ @@ -50,13 +47,35 @@ def name(self): return self.package.split('.')[-1] +class ResourceContainer(Traversable): + """ + Traversable container for a package's resources via its reader. + """ + + def __init__(self, reader: SimpleReader): + self.reader = reader + + def is_dir(self): + return True + + def is_file(self): + return False + + def iterdir(self): + files = (ResourceHandle(self, name) for name in self.reader.resources) + dirs = map(ResourceContainer, self.reader.children()) + return itertools.chain(files, dirs) + + def open(self, *args, **kwargs): + raise IsADirectoryError() + + class ResourceHandle(Traversable): """ Handle to a named resource in a ResourceReader. """ - def __init__(self, parent, name): - # type: (ResourceContainer, str) -> None + def __init__(self, parent: ResourceContainer, name: str): self.parent = parent self.name = name # type: ignore @@ -76,44 +95,6 @@ def joinpath(self, name): raise RuntimeError("Cannot traverse into a resource") -class ResourceContainer(Traversable): - """ - Traversable container for a package's resources via its reader. - """ - - def __init__(self, reader): - # type: (SimpleReader) -> None - self.reader = reader - - def is_dir(self): - return True - - def is_file(self): - return False - - def iterdir(self): - files = (ResourceHandle(self, name) for name in self.reader.resources) - dirs = map(ResourceContainer, self.reader.children()) - return itertools.chain(files, dirs) - - def open(self, *args, **kwargs): - raise IsADirectoryError() - - @staticmethod - def _flatten(compound_names): - for name in compound_names: - yield from name.split('/') - - def joinpath(self, *descendants): - if not descendants: - return self - names = self._flatten(descendants) - target = next(names) - return next( - traversable for traversable in self.iterdir() if traversable.name == target - ).joinpath(*names) - - class TraversableReader(TraversableResources, SimpleReader): """ A TraversableResources based on SimpleReader. Resource providers diff --git a/Lib/importlib/util.py b/Lib/importlib/util.py index 8623c89840..f4d6e82331 100644 --- a/Lib/importlib/util.py +++ b/Lib/importlib/util.py @@ -11,12 +11,9 @@ from ._bootstrap_external import source_from_cache from ._bootstrap_external import spec_from_file_location -from contextlib import contextmanager import _imp -import functools import sys import types -import warnings def source_hash(source_bytes): @@ -63,10 +60,10 @@ def _find_spec_from_path(name, path=None): try: spec = module.__spec__ except AttributeError: - raise ValueError('{}.__spec__ is not set'.format(name)) from None + raise ValueError(f'{name}.__spec__ is not set') from None else: if spec is None: - raise ValueError('{}.__spec__ is None'.format(name)) + raise ValueError(f'{name}.__spec__ is None') return spec @@ -108,115 +105,64 @@ def find_spec(name, package=None): try: spec = module.__spec__ except AttributeError: - raise ValueError('{}.__spec__ is not set'.format(name)) from None + raise ValueError(f'{name}.__spec__ is not set') from None else: if spec is None: - raise ValueError('{}.__spec__ is None'.format(name)) + raise ValueError(f'{name}.__spec__ is None') return spec -@contextmanager -def _module_to_load(name): - is_reload = name in sys.modules - - module = sys.modules.get(name) - if not is_reload: - # This must be done before open() is called as the 'io' module - # implicitly imports 'locale' and would otherwise trigger an - # infinite loop. - module = type(sys)(name) - # This must be done before putting the module in sys.modules - # (otherwise an optimization shortcut in import.c becomes wrong) - module.__initializing__ = True - sys.modules[name] = module - try: - yield module - except Exception: - if not is_reload: - try: - del sys.modules[name] - except KeyError: - pass - finally: - module.__initializing__ = False +# Normally we would use contextlib.contextmanager. However, this module +# is imported by runpy, which means we want to avoid any unnecessary +# dependencies. Thus we use a class. +class _incompatible_extension_module_restrictions: + """A context manager that can temporarily skip the compatibility check. -def set_package(fxn): - """Set __package__ on the returned module. + NOTE: This function is meant to accommodate an unusual case; one + which is likely to eventually go away. There's is a pretty good + chance this is not what you were looking for. - This function is deprecated. + WARNING: Using this function to disable the check can lead to + unexpected behavior and even crashes. It should only be used during + extension module development. - """ - @functools.wraps(fxn) - def set_package_wrapper(*args, **kwargs): - warnings.warn('The import system now takes care of this automatically; ' - 'this decorator is slated for removal in Python 3.12', - DeprecationWarning, stacklevel=2) - module = fxn(*args, **kwargs) - if getattr(module, '__package__', None) is None: - module.__package__ = module.__name__ - if not hasattr(module, '__path__'): - module.__package__ = module.__package__.rpartition('.')[0] - return module - return set_package_wrapper + If "disable_check" is True then the compatibility check will not + happen while the context manager is active. Otherwise the check + *will* happen. + Normally, extensions that do not support multiple interpreters + may not be imported in a subinterpreter. That implies modules + that do not implement multi-phase init or that explicitly of out. -def set_loader(fxn): - """Set __loader__ on the returned module. + Likewise for modules import in a subinterpeter with its own GIL + when the extension does not support a per-interpreter GIL. This + implies the module does not have a Py_mod_multiple_interpreters slot + set to Py_MOD_PER_INTERPRETER_GIL_SUPPORTED. - This function is deprecated. + In both cases, this context manager may be used to temporarily + disable the check for compatible extension modules. + You can get the same effect as this function by implementing the + basic interface of multi-phase init (PEP 489) and lying about + support for mulitple interpreters (or per-interpreter GIL). """ - @functools.wraps(fxn) - def set_loader_wrapper(self, *args, **kwargs): - warnings.warn('The import system now takes care of this automatically; ' - 'this decorator is slated for removal in Python 3.12', - DeprecationWarning, stacklevel=2) - module = fxn(self, *args, **kwargs) - if getattr(module, '__loader__', None) is None: - module.__loader__ = self - return module - return set_loader_wrapper - - -def module_for_loader(fxn): - """Decorator to handle selecting the proper module for loaders. - - The decorated function is passed the module to use instead of the module - name. The module passed in to the function is either from sys.modules if - it already exists or is a new module. If the module is new, then __name__ - is set the first argument to the method, __loader__ is set to self, and - __package__ is set accordingly (if self.is_package() is defined) will be set - before it is passed to the decorated function (if self.is_package() does - not work for the module it will be set post-load). - - If an exception is raised and the decorator created the module it is - subsequently removed from sys.modules. - - The decorator assumes that the decorated function takes the module name as - the second argument. - """ - warnings.warn('The import system now takes care of this automatically; ' - 'this decorator is slated for removal in Python 3.12', - DeprecationWarning, stacklevel=2) - @functools.wraps(fxn) - def module_for_loader_wrapper(self, fullname, *args, **kwargs): - with _module_to_load(fullname) as module: - module.__loader__ = self - try: - is_package = self.is_package(fullname) - except (ImportError, AttributeError): - pass - else: - if is_package: - module.__package__ = fullname - else: - module.__package__ = fullname.rpartition('.')[0] - # If __package__ was not set above, __import__() will do it later. - return fxn(self, module, *args, **kwargs) - - return module_for_loader_wrapper + def __init__(self, *, disable_check): + self.disable_check = bool(disable_check) + + def __enter__(self): + self.old = _imp._override_multi_interp_extensions_check(self.override) + return self + + def __exit__(self, *args): + old = self.old + del self.old + _imp._override_multi_interp_extensions_check(old) + + @property + def override(self): + return -1 if self.disable_check else 1 class _LazyModule(types.ModuleType): From 69bd951dbdea7a4550253fd59aa247add058849c Mon Sep 17 00:00:00 2001 From: CPython developers <> Date: Tue, 3 Oct 2023 23:51:58 +0900 Subject: [PATCH 03/19] Update test_importlib from Python3.12 --- Lib/test/test_importlib/_context.py | 13 + Lib/test/test_importlib/_path.py | 109 ++++++ .../test_importlib/builtin/test_finder.py | 50 --- .../extension/test_case_sensitivity.py | 2 +- .../test_importlib/extension/test_loader.py | 151 ++++++-- .../extension/test_path_hook.py | 2 +- Lib/test/test_importlib/fixtures.py | 158 +++++--- Lib/test/test_importlib/frozen/test_finder.py | 48 --- Lib/test/test_importlib/frozen/test_loader.py | 105 +---- .../test_importlib/import_/test___loader__.py | 43 --- .../import_/test___package__.py | 40 +- Lib/test/test_importlib/import_/test_api.py | 5 - .../test_importlib/import_/test_caching.py | 25 +- .../test_importlib/import_/test_helpers.py | 184 +++++++++ .../test_importlib/import_/test_meta_path.py | 10 - Lib/test/test_importlib/import_/test_path.py | 73 +--- Lib/test/test_importlib/resources/_path.py | 56 +++ .../{ => resources}/data01/__init__.py | 0 .../resources/data01/binary.file | Bin 0 -> 4 bytes .../data01/subdirectory/__init__.py | 0 .../resources/data01/subdirectory/binary.file | Bin 0 -> 4 bytes .../resources/data01/utf-16.file | Bin 0 -> 44 bytes .../{ => resources}/data01/utf-8.file | 0 .../{ => resources}/data02/__init__.py | 0 .../{ => resources}/data02/one/__init__.py | 0 .../{ => resources}/data02/one/resource1.txt | 0 .../subdirectory/subsubdir/resource.txt | 1 + .../{ => resources}/data02/two/__init__.py | 0 .../{ => resources}/data02/two/resource2.txt | 0 .../{ => resources}/data03/__init__.py | 0 .../data03/namespace/portion1/__init__.py | 0 .../data03/namespace/portion2/__init__.py | 0 .../data03/namespace/resource1.txt | 0 .../resources/namespacedata01/binary.file | Bin 0 -> 4 bytes .../resources/namespacedata01/utf-16.file | Bin 0 -> 44 bytes .../namespacedata01/utf-8.file | 0 .../test_compatibilty_files.py | 8 +- .../{ => resources}/test_contents.py | 2 +- .../test_importlib/resources/test_custom.py | 46 +++ .../test_importlib/resources/test_files.py | 113 ++++++ .../{ => resources}/test_open.py | 22 +- .../{ => resources}/test_path.py | 17 +- .../{ => resources}/test_read.py | 14 +- .../{ => resources}/test_reader.py | 16 + .../{ => resources}/test_resource.py | 110 +++--- .../{ => resources}/update-zips.py | 0 Lib/test/test_importlib/resources/util.py | 51 +-- .../{ => resources}/zipdata01/__init__.py | 0 .../resources/zipdata01/ziptestdata.zip | Bin 0 -> 876 bytes .../{ => resources}/zipdata02/__init__.py | 0 .../resources/zipdata02/ziptestdata.zip | Bin 0 -> 698 bytes .../source/test_case_sensitivity.py | 13 - .../test_importlib/source/test_file_loader.py | 1 - Lib/test/test_importlib/source/test_finder.py | 29 +- .../test_importlib/source/test_path_hook.py | 9 - Lib/test/test_importlib/test_abc.py | 118 +----- Lib/test/test_importlib/test_api.py | 71 ++-- Lib/test/test_importlib/test_files.py | 46 --- Lib/test/test_importlib/test_locks.py | 5 + Lib/test/test_importlib/test_main.py | 119 +++++- Lib/test/test_importlib/test_metadata_api.py | 112 +++--- .../test_importlib/test_namespace_pkgs.py | 7 +- Lib/test/test_importlib/test_spec.py | 143 ------- .../test_importlib/test_threaded_import.py | 11 +- Lib/test/test_importlib/test_util.py | 364 ++++++------------ Lib/test/test_importlib/test_windows.py | 18 +- Lib/test/test_importlib/util.py | 34 +- 67 files changed, 1236 insertions(+), 1338 deletions(-) create mode 100644 Lib/test/test_importlib/_context.py create mode 100644 Lib/test/test_importlib/_path.py create mode 100644 Lib/test/test_importlib/import_/test_helpers.py create mode 100644 Lib/test/test_importlib/resources/_path.py rename Lib/test/test_importlib/{ => resources}/data01/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/data01/binary.file rename Lib/test/test_importlib/{ => resources}/data01/subdirectory/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/data01/subdirectory/binary.file create mode 100644 Lib/test/test_importlib/resources/data01/utf-16.file rename Lib/test/test_importlib/{ => resources}/data01/utf-8.file (100%) rename Lib/test/test_importlib/{ => resources}/data02/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data02/one/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data02/one/resource1.txt (100%) create mode 100644 Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt rename Lib/test/test_importlib/{ => resources}/data02/two/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data02/two/resource2.txt (100%) rename Lib/test/test_importlib/{ => resources}/data03/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data03/namespace/portion1/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data03/namespace/portion2/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data03/namespace/resource1.txt (100%) create mode 100644 Lib/test/test_importlib/resources/namespacedata01/binary.file create mode 100644 Lib/test/test_importlib/resources/namespacedata01/utf-16.file rename Lib/test/test_importlib/{ => resources}/namespacedata01/utf-8.file (100%) rename Lib/test/test_importlib/{ => resources}/test_compatibilty_files.py (93%) rename Lib/test/test_importlib/{ => resources}/test_contents.py (97%) create mode 100644 Lib/test/test_importlib/resources/test_custom.py create mode 100644 Lib/test/test_importlib/resources/test_files.py rename Lib/test/test_importlib/{ => resources}/test_open.py (82%) rename Lib/test/test_importlib/{ => resources}/test_path.py (84%) rename Lib/test/test_importlib/{ => resources}/test_read.py (86%) rename Lib/test/test_importlib/{ => resources}/test_reader.py (85%) rename Lib/test/test_importlib/{ => resources}/test_resource.py (74%) rename Lib/test/test_importlib/{ => resources}/update-zips.py (100%) rename Lib/test/test_importlib/{ => resources}/zipdata01/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/zipdata01/ziptestdata.zip rename Lib/test/test_importlib/{ => resources}/zipdata02/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/zipdata02/ziptestdata.zip delete mode 100644 Lib/test/test_importlib/test_files.py diff --git a/Lib/test/test_importlib/_context.py b/Lib/test/test_importlib/_context.py new file mode 100644 index 0000000000..8a53eb55d1 --- /dev/null +++ b/Lib/test/test_importlib/_context.py @@ -0,0 +1,13 @@ +import contextlib + + +# from jaraco.context 4.3 +class suppress(contextlib.suppress, contextlib.ContextDecorator): + """ + A version of contextlib.suppress with decorator support. + + >>> @suppress(KeyError) + ... def key_error(): + ... {}[''] + >>> key_error() + """ diff --git a/Lib/test/test_importlib/_path.py b/Lib/test/test_importlib/_path.py new file mode 100644 index 0000000000..71a704389b --- /dev/null +++ b/Lib/test/test_importlib/_path.py @@ -0,0 +1,109 @@ +# from jaraco.path 3.5 + +import functools +import pathlib +from typing import Dict, Union + +try: + from typing import Protocol, runtime_checkable +except ImportError: # pragma: no cover + # Python 3.7 + from typing_extensions import Protocol, runtime_checkable # type: ignore + + +FilesSpec = Dict[str, Union[str, bytes, 'FilesSpec']] # type: ignore + + +@runtime_checkable +class TreeMaker(Protocol): + def __truediv__(self, *args, **kwargs): + ... # pragma: no cover + + def mkdir(self, **kwargs): + ... # pragma: no cover + + def write_text(self, content, **kwargs): + ... # pragma: no cover + + def write_bytes(self, content): + ... # pragma: no cover + + +def _ensure_tree_maker(obj: Union[str, TreeMaker]) -> TreeMaker: + return obj if isinstance(obj, TreeMaker) else pathlib.Path(obj) # type: ignore + + +def build( + spec: FilesSpec, + prefix: Union[str, TreeMaker] = pathlib.Path(), # type: ignore +): + """ + Build a set of files/directories, as described by the spec. + + Each key represents a pathname, and the value represents + the content. Content may be a nested directory. + + >>> spec = { + ... 'README.txt': "A README file", + ... "foo": { + ... "__init__.py": "", + ... "bar": { + ... "__init__.py": "", + ... }, + ... "baz.py": "# Some code", + ... } + ... } + >>> target = getfixture('tmp_path') + >>> build(spec, target) + >>> target.joinpath('foo/baz.py').read_text(encoding='utf-8') + '# Some code' + """ + for name, contents in spec.items(): + create(contents, _ensure_tree_maker(prefix) / name) + + +@functools.singledispatch +def create(content: Union[str, bytes, FilesSpec], path): + path.mkdir(exist_ok=True) + build(content, prefix=path) # type: ignore + + +@create.register +def _(content: bytes, path): + path.write_bytes(content) + + +@create.register +def _(content: str, path): + path.write_text(content, encoding='utf-8') + + +@create.register +def _(content: str, path): + path.write_text(content, encoding='utf-8') + + +class Recording: + """ + A TreeMaker object that records everything that would be written. + + >>> r = Recording() + >>> build({'foo': {'foo1.txt': 'yes'}, 'bar.txt': 'abc'}, r) + >>> r.record + ['foo/foo1.txt', 'bar.txt'] + """ + + def __init__(self, loc=pathlib.PurePosixPath(), record=None): + self.loc = loc + self.record = record if record is not None else [] + + def __truediv__(self, other): + return Recording(self.loc / other, self.record) + + def write_text(self, content, **kwargs): + self.record.append(str(self.loc)) + + write_bytes = write_text + + def mkdir(self, **kwargs): + return diff --git a/Lib/test/test_importlib/builtin/test_finder.py b/Lib/test/test_importlib/builtin/test_finder.py index a4869e07b9..111c4af1ea 100644 --- a/Lib/test/test_importlib/builtin/test_finder.py +++ b/Lib/test/test_importlib/builtin/test_finder.py @@ -37,61 +37,11 @@ def test_failure(self): spec = self.machinery.BuiltinImporter.find_spec(name) self.assertIsNone(spec) - def test_ignore_path(self): - # The value for 'path' should always trigger a failed import. - with util.uncache(util.BUILTINS.good_name): - spec = self.machinery.BuiltinImporter.find_spec(util.BUILTINS.good_name, - ['pkg']) - self.assertIsNone(spec) - (Frozen_FindSpecTests, Source_FindSpecTests ) = util.test_both(FindSpecTests, machinery=machinery) -@unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module') -class FinderTests(abc.FinderTests): - - """Test find_module() for built-in modules.""" - - def test_module(self): - # Common case. - with util.uncache(util.BUILTINS.good_name): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - found = self.machinery.BuiltinImporter.find_module(util.BUILTINS.good_name) - self.assertTrue(found) - self.assertTrue(hasattr(found, 'load_module')) - - # Built-in modules cannot be a package. - test_package = test_package_in_package = test_package_over_module = None - - # Built-in modules cannot be in a package. - test_module_in_package = None - - def test_failure(self): - assert 'importlib' not in sys.builtin_module_names - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.BuiltinImporter.find_module('importlib') - self.assertIsNone(loader) - - def test_ignore_path(self): - # The value for 'path' should always trigger a failed import. - with util.uncache(util.BUILTINS.good_name): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.BuiltinImporter.find_module( - util.BUILTINS.good_name, - ['pkg']) - self.assertIsNone(loader) - - -(Frozen_FinderTests, - Source_FinderTests - ) = util.test_both(FinderTests, machinery=machinery) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/extension/test_case_sensitivity.py b/Lib/test/test_importlib/extension/test_case_sensitivity.py index 366e565cf4..0bb74fff5f 100644 --- a/Lib/test/test_importlib/extension/test_case_sensitivity.py +++ b/Lib/test/test_importlib/extension/test_case_sensitivity.py @@ -8,7 +8,7 @@ machinery = util.import_importlib('importlib.machinery') -@unittest.skipIf(util.EXTENSIONS.filename is None, '_testcapi not available') +@unittest.skipIf(util.EXTENSIONS.filename is None, f'{util.EXTENSIONS.name} not available') @util.case_insensitive_tests class ExtensionModuleCaseSensitivityTest(util.CASEOKTestBase): diff --git a/Lib/test/test_importlib/extension/test_loader.py b/Lib/test/test_importlib/extension/test_loader.py index 6c5cd577c1..d06558f2ad 100644 --- a/Lib/test/test_importlib/extension/test_loader.py +++ b/Lib/test/test_importlib/extension/test_loader.py @@ -13,9 +13,9 @@ from test.support.script_helper import assert_python_failure -class LoaderTests(abc.LoaderTests): +class LoaderTests: - """Test load_module() for extension modules.""" + """Test ExtensionFileLoader.""" def setUp(self): if not self.machinery.EXTENSION_SUFFIXES: @@ -32,17 +32,6 @@ def load_module(self, fullname): warnings.simplefilter("ignore", DeprecationWarning) return self.loader.load_module(fullname) - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_load_module_API(self): - # Test the default argument for load_module(). - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - self.loader.load_module() - self.loader.load_module(None) - with self.assertRaises(ImportError): - self.load_module('XXX') - def test_equality(self): other = self.machinery.ExtensionFileLoader(util.EXTENSIONS.name, util.EXTENSIONS.file_path) @@ -53,6 +42,15 @@ def test_inequality(self): util.EXTENSIONS.file_path) self.assertNotEqual(self.loader, other) + def test_load_module_API(self): + # Test the default argument for load_module(). + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + self.loader.load_module() + self.loader.load_module(None) + with self.assertRaises(ImportError): + self.load_module('XXX') + # TODO: RUSTPYTHON @unittest.expectedFailure def test_module(self): @@ -72,14 +70,6 @@ def test_module(self): # No extension module in a package available for testing. test_lacking_parent = None - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_module_reuse(self): - with util.uncache(util.EXTENSIONS.name): - module1 = self.load_module(util.EXTENSIONS.name) - module2 = self.load_module(util.EXTENSIONS.name) - self.assertIs(module1, module2) - # No easy way to trigger a failure after a successful import. test_state_after_failure = None @@ -89,6 +79,12 @@ def test_unloadable(self): self.load_module(name) self.assertEqual(cm.exception.name, name) + def test_module_reuse(self): + with util.uncache(util.EXTENSIONS.name): + module1 = self.load_module(util.EXTENSIONS.name) + module2 = self.load_module(util.EXTENSIONS.name) + self.assertIs(module1, module2) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_is_package(self): @@ -98,11 +94,94 @@ def test_is_package(self): loader = self.machinery.ExtensionFileLoader('pkg', path) self.assertTrue(loader.is_package('pkg')) + (Frozen_LoaderTests, Source_LoaderTests ) = util.test_both(LoaderTests, machinery=machinery) -@unittest.skip("TODO: RUSTPYTHON, AssertionError") + +class SinglePhaseExtensionModuleTests(abc.LoaderTests): + # Test loading extension modules without multi-phase initialization. + + def setUp(self): + if not self.machinery.EXTENSION_SUFFIXES: + raise unittest.SkipTest("Requires dynamic loading support.") + self.name = '_testsinglephase' + if self.name in sys.builtin_module_names: + raise unittest.SkipTest( + f"{self.name} is a builtin module" + ) + finder = self.machinery.FileFinder(None) + self.spec = importlib.util.find_spec(self.name) + assert self.spec + self.loader = self.machinery.ExtensionFileLoader( + self.name, self.spec.origin) + + def load_module(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return self.loader.load_module(self.name) + + def load_module_by_name(self, fullname): + # Load a module from the test extension by name. + origin = self.spec.origin + loader = self.machinery.ExtensionFileLoader(fullname, origin) + spec = importlib.util.spec_from_loader(fullname, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + return module + + def test_module(self): + # Test loading an extension module. + with util.uncache(self.name): + module = self.load_module() + for attr, value in [('__name__', self.name), + ('__file__', self.spec.origin), + ('__package__', '')]: + self.assertEqual(getattr(module, attr), value) + with self.assertRaises(AttributeError): + module.__path__ + self.assertIs(module, sys.modules[self.name]) + self.assertIsInstance(module.__loader__, + self.machinery.ExtensionFileLoader) + + # No extension module as __init__ available for testing. + test_package = None + + # No extension module in a package available for testing. + test_lacking_parent = None + + # No easy way to trigger a failure after a successful import. + test_state_after_failure = None + + def test_unloadable(self): + name = 'asdfjkl;' + with self.assertRaises(ImportError) as cm: + self.load_module_by_name(name) + self.assertEqual(cm.exception.name, name) + + def test_unloadable_nonascii(self): + # Test behavior with nonexistent module with non-ASCII name. + name = 'fo\xf3' + with self.assertRaises(ImportError) as cm: + self.load_module_by_name(name) + self.assertEqual(cm.exception.name, name) + + # It may make sense to add the equivalent to + # the following MultiPhaseExtensionModuleTests tests: + # + # * test_nonmodule + # * test_nonmodule_with_methods + # * test_bad_modules + # * test_nonascii + + +(Frozen_SinglePhaseExtensionModuleTests, + Source_SinglePhaseExtensionModuleTests + ) = util.test_both(SinglePhaseExtensionModuleTests, machinery=machinery) + + +# @unittest.skip("TODO: RUSTPYTHON, AssertionError") class MultiPhaseExtensionModuleTests(abc.LoaderTests): # Test loading extension modules with multi-phase initialization (PEP 489). @@ -188,15 +267,16 @@ def test_reload(self): def test_try_registration(self): # Assert that the PyState_{Find,Add,Remove}Module C API doesn't work. - module = self.load_module() - with self.subTest('PyState_FindModule'): - self.assertEqual(module.call_state_registration_func(0), None) - with self.subTest('PyState_AddModule'): - with self.assertRaises(SystemError): - module.call_state_registration_func(1) - with self.subTest('PyState_RemoveModule'): - with self.assertRaises(SystemError): - module.call_state_registration_func(2) + with util.uncache(self.name): + module = self.load_module() + with self.subTest('PyState_FindModule'): + self.assertEqual(module.call_state_registration_func(0), None) + with self.subTest('PyState_AddModule'): + with self.assertRaises(SystemError): + module.call_state_registration_func(1) + with self.subTest('PyState_RemoveModule'): + with self.assertRaises(SystemError): + module.call_state_registration_func(2) def test_load_submodule(self): # Test loading a simulated submodule. @@ -274,12 +354,19 @@ def test_bad_modules(self): 'exec_err', 'exec_raise', 'exec_unreported_exception', + 'multiple_create_slots', + 'multiple_multiple_interpreters_slots', ]: with self.subTest(name_base): name = self.name + '_' + name_base - with self.assertRaises(SystemError): + with self.assertRaises(SystemError) as cm: self.load_module_by_name(name) + # If there is an unreported exception, it should be chained + # with the `SystemError`. + if "unreported_exception" in name_base: + self.assertIsNotNone(cm.exception.__cause__) + def test_nonascii(self): # Test that modules with non-ASCII names can be loaded. # punycode behaves slightly differently in some-ASCII and no-ASCII diff --git a/Lib/test/test_importlib/extension/test_path_hook.py b/Lib/test/test_importlib/extension/test_path_hook.py index a0adc70ad1..ec9644dc52 100644 --- a/Lib/test/test_importlib/extension/test_path_hook.py +++ b/Lib/test/test_importlib/extension/test_path_hook.py @@ -19,7 +19,7 @@ def hook(self, entry): def test_success(self): # Path hook should handle a directory where a known extension module # exists. - self.assertTrue(hasattr(self.hook(util.EXTENSIONS.path), 'find_module')) + self.assertTrue(hasattr(self.hook(util.EXTENSIONS.path), 'find_spec')) (Frozen_PathHooksTests, diff --git a/Lib/test/test_importlib/fixtures.py b/Lib/test/test_importlib/fixtures.py index e7be77b395..73e5da2ba9 100644 --- a/Lib/test/test_importlib/fixtures.py +++ b/Lib/test/test_importlib/fixtures.py @@ -10,7 +10,10 @@ from test.support.os_helper import FS_NONASCII from test.support import requires_zlib -from typing import Dict, Union + +from . import _path +from ._path import FilesSpec + try: from importlib import resources # type: ignore @@ -83,13 +86,8 @@ def setUp(self): self.fixtures.enter_context(self.add_sys_path(self.site_dir)) -# Except for python/mypy#731, prefer to define -# FilesDef = Dict[str, Union['FilesDef', str]] -FilesDef = Dict[str, Union[Dict[str, Union[Dict[str, str], str]], str]] - - class DistInfoPkg(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "distinfo_pkg-1.0.0.dist-info": { "METADATA": """ Name: distinfo-pkg @@ -131,7 +129,7 @@ def make_uppercase(self): class DistInfoPkgWithDot(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "pkg_dot-1.0.0.dist-info": { "METADATA": """ Name: pkg.dot @@ -146,7 +144,7 @@ def setUp(self): class DistInfoPkgWithDotLegacy(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "pkg.dot-1.0.0.dist-info": { "METADATA": """ Name: pkg.dot @@ -173,7 +171,7 @@ def setUp(self): class EggInfoPkg(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "egginfo_pkg.egg-info": { "PKG-INFO": """ Name: egginfo-pkg @@ -212,8 +210,99 @@ def setUp(self): build_files(EggInfoPkg.files, prefix=self.site_dir) +class EggInfoPkgPipInstalledNoToplevel(OnSysPath, SiteDir): + files: FilesSpec = { + "egg_with_module_pkg.egg-info": { + "PKG-INFO": "Name: egg_with_module-pkg", + # SOURCES.txt is made from the source archive, and contains files + # (setup.py) that are not present after installation. + "SOURCES.txt": """ + egg_with_module.py + setup.py + egg_with_module_pkg.egg-info/PKG-INFO + egg_with_module_pkg.egg-info/SOURCES.txt + egg_with_module_pkg.egg-info/top_level.txt + """, + # installed-files.txt is written by pip, and is a strictly more + # accurate source than SOURCES.txt as to the installed contents of + # the package. + "installed-files.txt": """ + ../egg_with_module.py + PKG-INFO + SOURCES.txt + top_level.txt + """, + # missing top_level.txt (to trigger fallback to installed-files.txt) + }, + "egg_with_module.py": """ + def main(): + print("hello world") + """, + } + + def setUp(self): + super().setUp() + build_files(EggInfoPkgPipInstalledNoToplevel.files, prefix=self.site_dir) + + +class EggInfoPkgPipInstalledNoModules(OnSysPath, SiteDir): + files: FilesSpec = { + "egg_with_no_modules_pkg.egg-info": { + "PKG-INFO": "Name: egg_with_no_modules-pkg", + # SOURCES.txt is made from the source archive, and contains files + # (setup.py) that are not present after installation. + "SOURCES.txt": """ + setup.py + egg_with_no_modules_pkg.egg-info/PKG-INFO + egg_with_no_modules_pkg.egg-info/SOURCES.txt + egg_with_no_modules_pkg.egg-info/top_level.txt + """, + # installed-files.txt is written by pip, and is a strictly more + # accurate source than SOURCES.txt as to the installed contents of + # the package. + "installed-files.txt": """ + PKG-INFO + SOURCES.txt + top_level.txt + """, + # top_level.txt correctly reflects that no modules are installed + "top_level.txt": b"\n", + }, + } + + def setUp(self): + super().setUp() + build_files(EggInfoPkgPipInstalledNoModules.files, prefix=self.site_dir) + + +class EggInfoPkgSourcesFallback(OnSysPath, SiteDir): + files: FilesSpec = { + "sources_fallback_pkg.egg-info": { + "PKG-INFO": "Name: sources_fallback-pkg", + # SOURCES.txt is made from the source archive, and contains files + # (setup.py) that are not present after installation. + "SOURCES.txt": """ + sources_fallback.py + setup.py + sources_fallback_pkg.egg-info/PKG-INFO + sources_fallback_pkg.egg-info/SOURCES.txt + """, + # missing installed-files.txt (i.e. not installed by pip) and + # missing top_level.txt (to trigger fallback to SOURCES.txt) + }, + "sources_fallback.py": """ + def main(): + print("hello world") + """, + } + + def setUp(self): + super().setUp() + build_files(EggInfoPkgSourcesFallback.files, prefix=self.site_dir) + + class EggInfoFile(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "egginfo_file.egg-info": """ Metadata-Version: 1.0 Name: egginfo_file @@ -233,38 +322,22 @@ def setUp(self): build_files(EggInfoFile.files, prefix=self.site_dir) -def build_files(file_defs, prefix=pathlib.Path()): - """Build a set of files/directories, as described by the +# dedent all text strings before writing +orig = _path.create.registry[str] +_path.create.register(str, lambda content, path: orig(DALS(content), path)) - file_defs dictionary. Each key/value pair in the dictionary is - interpreted as a filename/contents pair. If the contents value is a - dictionary, a directory is created, and the dictionary interpreted - as the files within it, recursively. - For example: +build_files = _path.build - {"README.txt": "A README file", - "foo": { - "__init__.py": "", - "bar": { - "__init__.py": "", - }, - "baz.py": "# Some code", - } - } - """ - for name, contents in file_defs.items(): - full_name = prefix / name - if isinstance(contents, dict): - full_name.mkdir() - build_files(contents, prefix=full_name) - else: - if isinstance(contents, bytes): - with full_name.open('wb') as f: - f.write(contents) - else: - with full_name.open('w', encoding='utf-8') as f: - f.write(DALS(contents)) + +def build_record(file_defs): + return ''.join(f'{name},,\n' for name in record_names(file_defs)) + + +def record_names(file_defs): + recording = _path.Recording() + _path.build(file_defs, recording) + return recording.record class FileBuilder: @@ -277,11 +350,6 @@ def DALS(str): return textwrap.dedent(str).lstrip() -class NullFinder: - def find_module(self, name): - pass - - @requires_zlib() class ZipFixtures: root = 'test.test_importlib.data' diff --git a/Lib/test/test_importlib/frozen/test_finder.py b/Lib/test/test_importlib/frozen/test_finder.py index a82148f865..5bb075f377 100644 --- a/Lib/test/test_importlib/frozen/test_finder.py +++ b/Lib/test/test_importlib/frozen/test_finder.py @@ -70,14 +70,6 @@ def check_search_locations(self, spec): expected = [os.path.dirname(filename)] self.assertListEqual(spec.submodule_search_locations, expected) - def test_package(self): - spec = self.find('__phello__') - self.assertIsNotNone(spec) - - def test_module_in_package(self): - spec = self.find('__phello__.spam', ['__phello__']) - self.assertIsNotNone(spec) - # TODO: RUSTPYTHON @unittest.expectedFailure def test_module(self): @@ -196,45 +188,5 @@ def test_not_using_frozen(self): ) = util.test_both(FindSpecTests, machinery=machinery) -class FinderTests(abc.FinderTests): - - """Test finding frozen modules.""" - - def find(self, name, path=None): - finder = self.machinery.FrozenImporter - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - with import_helper.frozen_modules(): - return finder.find_module(name, path) - - def test_module(self): - name = '__hello__' - loader = self.find(name) - self.assertTrue(hasattr(loader, 'load_module')) - - def test_package(self): - loader = self.find('__phello__') - self.assertTrue(hasattr(loader, 'load_module')) - - def test_module_in_package(self): - loader = self.find('__phello__.spam', ['__phello__']) - self.assertTrue(hasattr(loader, 'load_module')) - - # No frozen package within another package to test with. - test_package_in_package = None - - # No easy way to test. - test_package_over_module = None - - def test_failure(self): - loader = self.find('') - self.assertIsNone(loader) - - -(Frozen_FinderTests, - Source_FinderTests - ) = util.test_both(FinderTests, machinery=machinery) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/frozen/test_loader.py b/Lib/test/test_importlib/frozen/test_loader.py index db256ff0fb..4f1af454b5 100644 --- a/Lib/test/test_importlib/frozen/test_loader.py +++ b/Lib/test/test_importlib/frozen/test_loader.py @@ -103,15 +103,7 @@ def test_lacking_parent(self): expected=value)) self.assertEqual(output, 'Hello world!\n') - def test_module_repr(self): - name = '__hello__' - module, output = self.exec_module(name) - with deprecated(): - repr_str = self.machinery.FrozenImporter.module_repr(module) - self.assertEqual(repr_str, - "") - - def test_module_repr_indirect(self): + def test_module_repr_indirect_through_spec(self): name = '__hello__' module, output = self.exec_module(name) self.assertEqual(repr(module), @@ -133,101 +125,6 @@ def test_unloadable(self): ) = util.test_both(ExecModuleTests, machinery=machinery) -class LoaderTests(abc.LoaderTests): - - def load_module(self, name): - with fresh(name, oldapi=True): - module = self.machinery.FrozenImporter.load_module(name) - with captured_stdout() as stdout: - module.main() - return module, stdout - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_module(self): - module, stdout = self.load_module('__hello__') - filename = resolve_stdlib_file('__hello__') - check = {'__name__': '__hello__', - '__package__': '', - '__loader__': self.machinery.FrozenImporter, - '__file__': filename, - } - for attr, value in check.items(): - self.assertEqual(getattr(module, attr, None), value) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_package(self): - module, stdout = self.load_module('__phello__') - filename = resolve_stdlib_file('__phello__', ispkg=True) - pkgdir = os.path.dirname(filename) - check = {'__name__': '__phello__', - '__package__': '__phello__', - '__path__': [pkgdir], - '__loader__': self.machinery.FrozenImporter, - '__file__': filename, - } - for attr, value in check.items(): - attr_value = getattr(module, attr, None) - self.assertEqual(attr_value, value, - "for __phello__.%s, %r != %r" % - (attr, attr_value, value)) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_lacking_parent(self): - with util.uncache('__phello__'): - module, stdout = self.load_module('__phello__.spam') - filename = resolve_stdlib_file('__phello__.spam') - check = {'__name__': '__phello__.spam', - '__package__': '__phello__', - '__loader__': self.machinery.FrozenImporter, - '__file__': filename, - } - for attr, value in check.items(): - attr_value = getattr(module, attr) - self.assertEqual(attr_value, value, - "for __phello__.spam.%s, %r != %r" % - (attr, attr_value, value)) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - def test_module_reuse(self): - with fresh('__hello__', oldapi=True): - module1 = self.machinery.FrozenImporter.load_module('__hello__') - module2 = self.machinery.FrozenImporter.load_module('__hello__') - with captured_stdout() as stdout: - module1.main() - module2.main() - self.assertIs(module1, module2) - self.assertEqual(stdout.getvalue(), - 'Hello world!\nHello world!\n') - - def test_module_repr(self): - with fresh('__hello__', oldapi=True): - module = self.machinery.FrozenImporter.load_module('__hello__') - repr_str = self.machinery.FrozenImporter.module_repr(module) - self.assertEqual(repr_str, - "") - - # No way to trigger an error in a frozen module. - test_state_after_failure = None - - def test_unloadable(self): - with import_helper.frozen_modules(): - with deprecated(): - assert self.machinery.FrozenImporter.find_module('_not_real') is None - with self.assertRaises(ImportError) as cm: - self.load_module('_not_real') - self.assertEqual(cm.exception.name, '_not_real') - - -(Frozen_LoaderTests, - Source_LoaderTests - ) = util.test_both(LoaderTests, machinery=machinery) - - class InspectLoaderTests: """Tests for the InspectLoader methods for FrozenImporter.""" diff --git a/Lib/test/test_importlib/import_/test___loader__.py b/Lib/test/test_importlib/import_/test___loader__.py index eaf665a6f5..a14163919a 100644 --- a/Lib/test/test_importlib/import_/test___loader__.py +++ b/Lib/test/test_importlib/import_/test___loader__.py @@ -33,48 +33,5 @@ def test___loader__(self): ) = util.test_both(SpecLoaderAttributeTests, __import__=util.__import__) -class LoaderMock: - - def find_module(self, fullname, path=None): - return self - - def load_module(self, fullname): - sys.modules[fullname] = self.module - return self.module - - -class LoaderAttributeTests: - - def test___loader___missing(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - module = types.ModuleType('blah') - try: - del module.__loader__ - except AttributeError: - pass - loader = LoaderMock() - loader.module = module - with util.uncache('blah'), util.import_state(meta_path=[loader]): - module = self.__import__('blah') - self.assertEqual(loader, module.__loader__) - - def test___loader___is_None(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - module = types.ModuleType('blah') - module.__loader__ = None - loader = LoaderMock() - loader.module = module - with util.uncache('blah'), util.import_state(meta_path=[loader]): - returned_module = self.__import__('blah') - self.assertEqual(loader, module.__loader__) - - -(Frozen_Tests, - Source_Tests - ) = util.test_both(LoaderAttributeTests, __import__=util.__import__) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/import_/test___package__.py b/Lib/test/test_importlib/import_/test___package__.py index cc2fa0f459..431faea5b4 100644 --- a/Lib/test/test_importlib/import_/test___package__.py +++ b/Lib/test/test_importlib/import_/test___package__.py @@ -78,8 +78,8 @@ def test_spec_fallback(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_warn_when_package_and_spec_disagree(self): - # Raise an ImportWarning if __package__ != __spec__.parent. - with self.assertWarns(ImportWarning): + # Raise a DeprecationWarning if __package__ != __spec__.parent. + with self.assertWarns(DeprecationWarning): self.import_module({'__package__': 'pkg.fake', '__spec__': FakeSpec('pkg.fakefake')}) @@ -99,25 +99,6 @@ def __init__(self, parent): self.parent = parent -class Using__package__PEP302(Using__package__): - mock_modules = util.mock_modules - - def test_using___package__(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_using___package__() - - def test_spec_fallback(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_spec_fallback() - - -(Frozen_UsingPackagePEP302, - Source_UsingPackagePEP302 - ) = util.test_both(Using__package__PEP302, __import__=util.__import__) - - class Using__package__PEP451(Using__package__): mock_modules = util.mock_spec @@ -166,23 +147,6 @@ def test_submodule(self): module = getattr(pkg, 'mod') self.assertEqual(module.__package__, 'pkg') -class Setting__package__PEP302(Setting__package__, unittest.TestCase): - mock_modules = util.mock_modules - - def test_top_level(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_top_level() - - def test_package(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_package() - - def test_submodule(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_submodule() class Setting__package__PEP451(Setting__package__, unittest.TestCase): mock_modules = util.mock_spec diff --git a/Lib/test/test_importlib/import_/test_api.py b/Lib/test/test_importlib/import_/test_api.py index 0ee032b020..d6ad590b3d 100644 --- a/Lib/test/test_importlib/import_/test_api.py +++ b/Lib/test/test_importlib/import_/test_api.py @@ -28,11 +28,6 @@ def exec_module(module): class BadLoaderFinder: - @classmethod - def find_module(cls, fullname, path): - if fullname == SUBMOD_NAME: - return cls - @classmethod def load_module(cls, fullname): if fullname == SUBMOD_NAME: diff --git a/Lib/test/test_importlib/import_/test_caching.py b/Lib/test/test_importlib/import_/test_caching.py index 3ca765fb4a..aedf0fd4f9 100644 --- a/Lib/test/test_importlib/import_/test_caching.py +++ b/Lib/test/test_importlib/import_/test_caching.py @@ -52,12 +52,11 @@ class ImportlibUseCache(UseCache, unittest.TestCase): __import__ = util.__import__['Source'] def create_mock(self, *names, return_=None): - mock = util.mock_modules(*names) - original_load = mock.load_module - def load_module(self, fullname): - original_load(fullname) - return return_ - mock.load_module = MethodType(load_module, mock) + mock = util.mock_spec(*names) + original_spec = mock.find_spec + def find_spec(self, fullname, path, target=None): + return original_spec(fullname) + mock.find_spec = MethodType(find_spec, mock) return mock # __import__ inconsistent between loaders and built-in import when it comes @@ -86,14 +85,12 @@ def test_using_cache_for_assigning_to_attribute(self): # See test_using_cache_after_loader() for reasoning. def test_using_cache_for_fromlist(self): # [from cache for fromlist] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - with self.create_mock('pkg.__init__', 'pkg.module') as importer: - with util.import_state(meta_path=[importer]): - module = self.__import__('pkg', fromlist=['module']) - self.assertTrue(hasattr(module, 'module')) - self.assertEqual(id(module.module), - id(sys.modules['pkg.module'])) + with self.create_mock('pkg.__init__', 'pkg.module') as importer: + with util.import_state(meta_path=[importer]): + module = self.__import__('pkg', fromlist=['module']) + self.assertTrue(hasattr(module, 'module')) + self.assertEqual(id(module.module), + id(sys.modules['pkg.module'])) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test_helpers.py b/Lib/test/test_importlib/import_/test_helpers.py new file mode 100644 index 0000000000..550f88d1d7 --- /dev/null +++ b/Lib/test/test_importlib/import_/test_helpers.py @@ -0,0 +1,184 @@ +"""Tests for helper functions used by import.c .""" + +from importlib import _bootstrap_external, machinery +import os.path +from types import ModuleType, SimpleNamespace +import unittest +import warnings + +from .. import util + + +class FixUpModuleTests: + + def test_no_loader_but_spec(self): + loader = object() + name = "hello" + path = "hello.py" + spec = machinery.ModuleSpec(name, loader) + ns = {"__spec__": spec} + _bootstrap_external._fix_up_module(ns, name, path) + + expected = {"__spec__": spec, "__loader__": loader, "__file__": path, + "__cached__": None} + self.assertEqual(ns, expected) + + def test_no_loader_no_spec_but_sourceless(self): + name = "hello" + path = "hello.py" + ns = {} + _bootstrap_external._fix_up_module(ns, name, path, path) + + expected = {"__file__": path, "__cached__": path} + + for key, val in expected.items(): + with self.subTest(f"{key}: {val}"): + self.assertEqual(ns[key], val) + + spec = ns["__spec__"] + self.assertIsInstance(spec, machinery.ModuleSpec) + self.assertEqual(spec.name, name) + self.assertEqual(spec.origin, os.path.abspath(path)) + self.assertEqual(spec.cached, os.path.abspath(path)) + self.assertIsInstance(spec.loader, machinery.SourcelessFileLoader) + self.assertEqual(spec.loader.name, name) + self.assertEqual(spec.loader.path, path) + self.assertEqual(spec.loader, ns["__loader__"]) + + def test_no_loader_no_spec_but_source(self): + name = "hello" + path = "hello.py" + ns = {} + _bootstrap_external._fix_up_module(ns, name, path) + + expected = {"__file__": path, "__cached__": None} + + for key, val in expected.items(): + with self.subTest(f"{key}: {val}"): + self.assertEqual(ns[key], val) + + spec = ns["__spec__"] + self.assertIsInstance(spec, machinery.ModuleSpec) + self.assertEqual(spec.name, name) + self.assertEqual(spec.origin, os.path.abspath(path)) + self.assertIsInstance(spec.loader, machinery.SourceFileLoader) + self.assertEqual(spec.loader.name, name) + self.assertEqual(spec.loader.path, path) + self.assertEqual(spec.loader, ns["__loader__"]) + + +FrozenFixUpModuleTests, SourceFixUpModuleTests = util.test_both(FixUpModuleTests) + + +class TestBlessMyLoader(unittest.TestCase): + # GH#86298 is part of the migration away from module attributes and toward + # __spec__ attributes. There are several cases to test here. This will + # have to change in Python 3.14 when we actually remove/ignore __loader__ + # in favor of requiring __spec__.loader. + + def test_gh86298_no_loader_and_no_spec(self): + bar = ModuleType('bar') + del bar.__loader__ + del bar.__spec__ + # 2022-10-06(warsaw): For backward compatibility with the + # implementation in _warnings.c, this can't raise an + # AttributeError. See _bless_my_loader() in _bootstrap_external.py + # If working with a module: + ## self.assertRaises( + ## AttributeError, _bootstrap_external._bless_my_loader, + ## bar.__dict__) + self.assertIsNone(_bootstrap_external._bless_my_loader(bar.__dict__)) + + def test_gh86298_loader_is_none_and_no_spec(self): + bar = ModuleType('bar') + bar.__loader__ = None + del bar.__spec__ + # 2022-10-06(warsaw): For backward compatibility with the + # implementation in _warnings.c, this can't raise an + # AttributeError. See _bless_my_loader() in _bootstrap_external.py + # If working with a module: + ## self.assertRaises( + ## AttributeError, _bootstrap_external._bless_my_loader, + ## bar.__dict__) + self.assertIsNone(_bootstrap_external._bless_my_loader(bar.__dict__)) + + def test_gh86298_no_loader_and_spec_is_none(self): + bar = ModuleType('bar') + del bar.__loader__ + bar.__spec__ = None + self.assertRaises( + ValueError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_loader_is_none_and_spec_is_none(self): + bar = ModuleType('bar') + bar.__loader__ = None + bar.__spec__ = None + self.assertRaises( + ValueError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_loader_is_none_and_spec_loader_is_none(self): + bar = ModuleType('bar') + bar.__loader__ = None + bar.__spec__ = SimpleNamespace(loader=None) + self.assertRaises( + ValueError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_no_spec(self): + bar = ModuleType('bar') + bar.__loader__ = object() + del bar.__spec__ + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_spec_is_none(self): + bar = ModuleType('bar') + bar.__loader__ = object() + bar.__spec__ = None + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_no_spec_loader(self): + bar = ModuleType('bar') + bar.__loader__ = object() + bar.__spec__ = SimpleNamespace() + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_loader_and_spec_loader_disagree(self): + bar = ModuleType('bar') + bar.__loader__ = object() + bar.__spec__ = SimpleNamespace(loader=object()) + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_no_loader_and_no_spec_loader(self): + bar = ModuleType('bar') + del bar.__loader__ + bar.__spec__ = SimpleNamespace() + self.assertRaises( + AttributeError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_no_loader_with_spec_loader_okay(self): + bar = ModuleType('bar') + del bar.__loader__ + loader = object() + bar.__spec__ = SimpleNamespace(loader=loader) + self.assertEqual( + _bootstrap_external._bless_my_loader(bar.__dict__), + loader) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_meta_path.py b/Lib/test/test_importlib/import_/test_meta_path.py index c52fc57065..26e7b070b9 100644 --- a/Lib/test/test_importlib/import_/test_meta_path.py +++ b/Lib/test/test_importlib/import_/test_meta_path.py @@ -115,16 +115,6 @@ def test_with_path(self): super().test_no_path() -class CallSignaturePEP302(CallSignoreSuppressImportWarning): - mock_modules = util.mock_modules - finder_name = 'find_module' - - -(Frozen_CallSignaturePEP302, - Source_CallSignaturePEP302 - ) = util.test_both(CallSignaturePEP302, __import__=util.__import__) - - class CallSignaturePEP451(CallSignature): mock_modules = util.mock_spec finder_name = 'find_spec' diff --git a/Lib/test/test_importlib/import_/test_path.py b/Lib/test/test_importlib/import_/test_path.py index 3873d9f3ed..9cf3a77cb8 100644 --- a/Lib/test/test_importlib/import_/test_path.py +++ b/Lib/test/test_importlib/import_/test_path.py @@ -118,46 +118,6 @@ def test_None_on_sys_path(self): if email is not missing: sys.modules['email'] = email - def test_finder_with_find_module(self): - class TestFinder: - def find_module(self, fullname): - return self.to_return - failing_finder = TestFinder() - failing_finder.to_return = None - path = 'testing path' - with util.import_state(path_importer_cache={path: failing_finder}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - self.assertIsNone( - self.machinery.PathFinder.find_spec('whatever', [path])) - success_finder = TestFinder() - success_finder.to_return = __loader__ - with util.import_state(path_importer_cache={path: success_finder}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - spec = self.machinery.PathFinder.find_spec('whatever', [path]) - self.assertEqual(spec.loader, __loader__) - - def test_finder_with_find_loader(self): - class TestFinder: - loader = None - portions = [] - def find_loader(self, fullname): - return self.loader, self.portions - path = 'testing path' - with util.import_state(path_importer_cache={path: TestFinder()}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - self.assertIsNone( - self.machinery.PathFinder.find_spec('whatever', [path])) - success_finder = TestFinder() - success_finder.loader = __loader__ - with util.import_state(path_importer_cache={path: success_finder}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - spec = self.machinery.PathFinder.find_spec('whatever', [path]) - self.assertEqual(spec.loader, __loader__) - def test_finder_with_find_spec(self): class TestFinder: spec = None @@ -230,9 +190,9 @@ def invalidate_caches(self): class FindModuleTests(FinderTests): def find(self, *args, **kwargs): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return self.machinery.PathFinder.find_module(*args, **kwargs) + spec = self.machinery.PathFinder.find_spec(*args, **kwargs) + return None if spec is None else spec.loader + def check_found(self, found, importer): self.assertIs(found, importer) @@ -257,16 +217,14 @@ def check_found(self, found, importer): class PathEntryFinderTests: def test_finder_with_failing_find_spec(self): - # PathEntryFinder with find_module() defined should work. - # Issue #20763. class Finder: - path_location = 'test_finder_with_find_module' + path_location = 'test_finder_with_find_spec' def __init__(self, path): if path != self.path_location: raise ImportError @staticmethod - def find_module(fullname): + def find_spec(fullname, target=None): return None @@ -276,27 +234,6 @@ def find_module(fullname): warnings.simplefilter("ignore", ImportWarning) self.machinery.PathFinder.find_spec('importlib') - def test_finder_with_failing_find_module(self): - # PathEntryFinder with find_module() defined should work. - # Issue #20763. - class Finder: - path_location = 'test_finder_with_find_module' - def __init__(self, path): - if path != self.path_location: - raise ImportError - - @staticmethod - def find_module(fullname): - return None - - - with util.import_state(path=[Finder.path_location]+sys.path[:], - path_hooks=[Finder]): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - warnings.simplefilter("ignore", DeprecationWarning) - self.machinery.PathFinder.find_module('importlib') - (Frozen_PEFTests, Source_PEFTests diff --git a/Lib/test/test_importlib/resources/_path.py b/Lib/test/test_importlib/resources/_path.py new file mode 100644 index 0000000000..1f97c96146 --- /dev/null +++ b/Lib/test/test_importlib/resources/_path.py @@ -0,0 +1,56 @@ +import pathlib +import functools + +from typing import Dict, Union + + +#### +# from jaraco.path 3.4.1 + +FilesSpec = Dict[str, Union[str, bytes, 'FilesSpec']] # type: ignore + + +def build(spec: FilesSpec, prefix=pathlib.Path()): + """ + Build a set of files/directories, as described by the spec. + + Each key represents a pathname, and the value represents + the content. Content may be a nested directory. + + >>> spec = { + ... 'README.txt': "A README file", + ... "foo": { + ... "__init__.py": "", + ... "bar": { + ... "__init__.py": "", + ... }, + ... "baz.py": "# Some code", + ... } + ... } + >>> target = getfixture('tmp_path') + >>> build(spec, target) + >>> target.joinpath('foo/baz.py').read_text(encoding='utf-8') + '# Some code' + """ + for name, contents in spec.items(): + create(contents, pathlib.Path(prefix) / name) + + +@functools.singledispatch +def create(content: Union[str, bytes, FilesSpec], path): + path.mkdir(exist_ok=True) + build(content, prefix=path) # type: ignore + + +@create.register +def _(content: bytes, path): + path.write_bytes(content) + + +@create.register +def _(content: str, path): + path.write_text(content, encoding='utf-8') + + +# end from jaraco.path +#### diff --git a/Lib/test/test_importlib/data01/__init__.py b/Lib/test/test_importlib/resources/data01/__init__.py similarity index 100% rename from Lib/test/test_importlib/data01/__init__.py rename to Lib/test/test_importlib/resources/data01/__init__.py diff --git a/Lib/test/test_importlib/resources/data01/binary.file b/Lib/test/test_importlib/resources/data01/binary.file new file mode 100644 index 0000000000000000000000000000000000000000..eaf36c1daccfdf325514461cd1a2ffbc139b5464 GIT binary patch literal 4 LcmZQzWMT#Y01f~L literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/data01/subdirectory/__init__.py b/Lib/test/test_importlib/resources/data01/subdirectory/__init__.py similarity index 100% rename from Lib/test/test_importlib/data01/subdirectory/__init__.py rename to Lib/test/test_importlib/resources/data01/subdirectory/__init__.py diff --git a/Lib/test/test_importlib/resources/data01/subdirectory/binary.file b/Lib/test/test_importlib/resources/data01/subdirectory/binary.file new file mode 100644 index 0000000000000000000000000000000000000000..eaf36c1daccfdf325514461cd1a2ffbc139b5464 GIT binary patch literal 4 LcmZQzWMT#Y01f~L literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/resources/data01/utf-16.file b/Lib/test/test_importlib/resources/data01/utf-16.file new file mode 100644 index 0000000000000000000000000000000000000000..2cb772295ef4b480a8d83725bd5006a0236d8f68 GIT binary patch literal 44 ucmezW&x0YAAqNQa8FUyF7(y9B7~B|i84MZBfV^^`Xc15@g+Y;liva-T)Ce>H literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/data01/utf-8.file b/Lib/test/test_importlib/resources/data01/utf-8.file similarity index 100% rename from Lib/test/test_importlib/data01/utf-8.file rename to Lib/test/test_importlib/resources/data01/utf-8.file diff --git a/Lib/test/test_importlib/data02/__init__.py b/Lib/test/test_importlib/resources/data02/__init__.py similarity index 100% rename from Lib/test/test_importlib/data02/__init__.py rename to Lib/test/test_importlib/resources/data02/__init__.py diff --git a/Lib/test/test_importlib/data02/one/__init__.py b/Lib/test/test_importlib/resources/data02/one/__init__.py similarity index 100% rename from Lib/test/test_importlib/data02/one/__init__.py rename to Lib/test/test_importlib/resources/data02/one/__init__.py diff --git a/Lib/test/test_importlib/data02/one/resource1.txt b/Lib/test/test_importlib/resources/data02/one/resource1.txt similarity index 100% rename from Lib/test/test_importlib/data02/one/resource1.txt rename to Lib/test/test_importlib/resources/data02/one/resource1.txt diff --git a/Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt b/Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt new file mode 100644 index 0000000000..48f587a2d0 --- /dev/null +++ b/Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt @@ -0,0 +1 @@ +a resource \ No newline at end of file diff --git a/Lib/test/test_importlib/data02/two/__init__.py b/Lib/test/test_importlib/resources/data02/two/__init__.py similarity index 100% rename from Lib/test/test_importlib/data02/two/__init__.py rename to Lib/test/test_importlib/resources/data02/two/__init__.py diff --git a/Lib/test/test_importlib/data02/two/resource2.txt b/Lib/test/test_importlib/resources/data02/two/resource2.txt similarity index 100% rename from Lib/test/test_importlib/data02/two/resource2.txt rename to Lib/test/test_importlib/resources/data02/two/resource2.txt diff --git a/Lib/test/test_importlib/data03/__init__.py b/Lib/test/test_importlib/resources/data03/__init__.py similarity index 100% rename from Lib/test/test_importlib/data03/__init__.py rename to Lib/test/test_importlib/resources/data03/__init__.py diff --git a/Lib/test/test_importlib/data03/namespace/portion1/__init__.py b/Lib/test/test_importlib/resources/data03/namespace/portion1/__init__.py similarity index 100% rename from Lib/test/test_importlib/data03/namespace/portion1/__init__.py rename to Lib/test/test_importlib/resources/data03/namespace/portion1/__init__.py diff --git a/Lib/test/test_importlib/data03/namespace/portion2/__init__.py b/Lib/test/test_importlib/resources/data03/namespace/portion2/__init__.py similarity index 100% rename from Lib/test/test_importlib/data03/namespace/portion2/__init__.py rename to Lib/test/test_importlib/resources/data03/namespace/portion2/__init__.py diff --git a/Lib/test/test_importlib/data03/namespace/resource1.txt b/Lib/test/test_importlib/resources/data03/namespace/resource1.txt similarity index 100% rename from Lib/test/test_importlib/data03/namespace/resource1.txt rename to Lib/test/test_importlib/resources/data03/namespace/resource1.txt diff --git a/Lib/test/test_importlib/resources/namespacedata01/binary.file b/Lib/test/test_importlib/resources/namespacedata01/binary.file new file mode 100644 index 0000000000000000000000000000000000000000..eaf36c1daccfdf325514461cd1a2ffbc139b5464 GIT binary patch literal 4 LcmZQzWMT#Y01f~L literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/resources/namespacedata01/utf-16.file b/Lib/test/test_importlib/resources/namespacedata01/utf-16.file new file mode 100644 index 0000000000000000000000000000000000000000..2cb772295ef4b480a8d83725bd5006a0236d8f68 GIT binary patch literal 44 ucmezW&x0YAAqNQa8FUyF7(y9B7~B|i84MZBfV^^`Xc15@g+Y;liva-T)Ce>H literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/namespacedata01/utf-8.file b/Lib/test/test_importlib/resources/namespacedata01/utf-8.file similarity index 100% rename from Lib/test/test_importlib/namespacedata01/utf-8.file rename to Lib/test/test_importlib/resources/namespacedata01/utf-8.file diff --git a/Lib/test/test_importlib/test_compatibilty_files.py b/Lib/test/test_importlib/resources/test_compatibilty_files.py similarity index 93% rename from Lib/test/test_importlib/test_compatibilty_files.py rename to Lib/test/test_importlib/resources/test_compatibilty_files.py index 9a823f2d93..bcf608d9e2 100644 --- a/Lib/test/test_importlib/test_compatibilty_files.py +++ b/Lib/test/test_importlib/resources/test_compatibilty_files.py @@ -8,7 +8,7 @@ wrap_spec, ) -from .resources import util +from . import util class CompatibilityFilesTests(unittest.TestCase): @@ -64,11 +64,13 @@ def test_orphan_path_name(self): def test_spec_path_open(self): self.assertEqual(self.files.read_bytes(), b'Hello, world!') - self.assertEqual(self.files.read_text(), 'Hello, world!') + self.assertEqual(self.files.read_text(encoding='utf-8'), 'Hello, world!') def test_child_path_open(self): self.assertEqual((self.files / 'a').read_bytes(), b'Hello, world!') - self.assertEqual((self.files / 'a').read_text(), 'Hello, world!') + self.assertEqual( + (self.files / 'a').read_text(encoding='utf-8'), 'Hello, world!' + ) def test_orphan_path_open(self): with self.assertRaises(FileNotFoundError): diff --git a/Lib/test/test_importlib/test_contents.py b/Lib/test/test_importlib/resources/test_contents.py similarity index 97% rename from Lib/test/test_importlib/test_contents.py rename to Lib/test/test_importlib/resources/test_contents.py index 3323bf5b5c..1a13f043a8 100644 --- a/Lib/test/test_importlib/test_contents.py +++ b/Lib/test/test_importlib/resources/test_contents.py @@ -2,7 +2,7 @@ from importlib import resources from . import data01 -from .resources import util +from . import util class ContentsTests: diff --git a/Lib/test/test_importlib/resources/test_custom.py b/Lib/test/test_importlib/resources/test_custom.py new file mode 100644 index 0000000000..73127209a2 --- /dev/null +++ b/Lib/test/test_importlib/resources/test_custom.py @@ -0,0 +1,46 @@ +import unittest +import contextlib +import pathlib + +from test.support import os_helper + +from importlib import resources +from importlib.resources.abc import TraversableResources, ResourceReader +from . import util + + +class SimpleLoader: + """ + A simple loader that only implements a resource reader. + """ + + def __init__(self, reader: ResourceReader): + self.reader = reader + + def get_resource_reader(self, package): + return self.reader + + +class MagicResources(TraversableResources): + """ + Magically returns the resources at path. + """ + + def __init__(self, path: pathlib.Path): + self.path = path + + def files(self): + return self.path + + +class CustomTraversableResourcesTests(unittest.TestCase): + def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + + def test_custom_loader(self): + temp_dir = self.fixtures.enter_context(os_helper.temp_dir()) + loader = SimpleLoader(MagicResources(temp_dir)) + pkg = util.create_package_from_loader(loader) + files = resources.files(pkg) + assert files is temp_dir diff --git a/Lib/test/test_importlib/resources/test_files.py b/Lib/test/test_importlib/resources/test_files.py new file mode 100644 index 0000000000..1450cfb310 --- /dev/null +++ b/Lib/test/test_importlib/resources/test_files.py @@ -0,0 +1,113 @@ +import typing +import textwrap +import unittest +import warnings +import importlib +import contextlib + +from importlib import resources +from importlib.resources.abc import Traversable +from . import data01 +from . import util +from . import _path +from test.support import os_helper +from test.support import import_helper + + +@contextlib.contextmanager +def suppress_known_deprecation(): + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter('default', category=DeprecationWarning) + yield ctx + + +class FilesTests: + def test_read_bytes(self): + files = resources.files(self.data) + actual = files.joinpath('utf-8.file').read_bytes() + assert actual == b'Hello, UTF-8 world!\n' + + def test_read_text(self): + files = resources.files(self.data) + actual = files.joinpath('utf-8.file').read_text(encoding='utf-8') + assert actual == 'Hello, UTF-8 world!\n' + + @unittest.skipUnless( + hasattr(typing, 'runtime_checkable'), + "Only suitable when typing supports runtime_checkable", + ) + def test_traversable(self): + assert isinstance(resources.files(self.data), Traversable) + + def test_old_parameter(self): + """ + Files used to take a 'package' parameter. Make sure anyone + passing by name is still supported. + """ + with suppress_known_deprecation(): + resources.files(package=self.data) + + +class OpenDiskTests(FilesTests, unittest.TestCase): + def setUp(self): + self.data = data01 + + +class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase): + pass + + +class OpenNamespaceTests(FilesTests, unittest.TestCase): + def setUp(self): + from . import namespacedata01 + + self.data = namespacedata01 + + +class SiteDir: + def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + self.site_dir = self.fixtures.enter_context(os_helper.temp_dir()) + self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir)) + self.fixtures.enter_context(import_helper.CleanImport()) + + +class ModulesFilesTests(SiteDir, unittest.TestCase): + def test_module_resources(self): + """ + A module can have resources found adjacent to the module. + """ + spec = { + 'mod.py': '', + 'res.txt': 'resources are the best', + } + _path.build(spec, self.site_dir) + import mod + + actual = resources.files(mod).joinpath('res.txt').read_text(encoding='utf-8') + assert actual == spec['res.txt'] + + +class ImplicitContextFilesTests(SiteDir, unittest.TestCase): + def test_implicit_files(self): + """ + Without any parameter, files() will infer the location as the caller. + """ + spec = { + 'somepkg': { + '__init__.py': textwrap.dedent( + """ + import importlib.resources as res + val = res.files().joinpath('res.txt').read_text(encoding='utf-8') + """ + ), + 'res.txt': 'resources are the best', + }, + } + _path.build(spec, self.site_dir) + assert importlib.import_module('somepkg').val == 'resources are the best' + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_importlib/test_open.py b/Lib/test/test_importlib/resources/test_open.py similarity index 82% rename from Lib/test/test_importlib/test_open.py rename to Lib/test/test_importlib/resources/test_open.py index fc0136e865..86becb4bfa 100644 --- a/Lib/test/test_importlib/test_open.py +++ b/Lib/test/test_importlib/resources/test_open.py @@ -2,7 +2,7 @@ from importlib import resources from . import data01 -from .resources import util +from . import util class CommonBinaryTests(util.CommonTests, unittest.TestCase): @@ -15,7 +15,7 @@ def execute(self, package, path): class CommonTextTests(util.CommonTests, unittest.TestCase): def execute(self, package, path): target = resources.files(package).joinpath(path) - with target.open(): + with target.open(encoding='utf-8'): pass @@ -28,7 +28,7 @@ def test_open_binary(self): def test_open_text_default_encoding(self): target = resources.files(self.data) / 'utf-8.file' - with target.open() as fp: + with target.open(encoding='utf-8') as fp: result = fp.read() self.assertEqual(result, 'Hello, UTF-8 world!\n') @@ -39,7 +39,9 @@ def test_open_text_given_encoding(self): self.assertEqual(result, 'Hello, UTF-16 world!\n') def test_open_text_with_errors(self): - # Raises UnicodeError without the 'errors' argument. + """ + Raises UnicodeError without the 'errors' argument. + """ target = resources.files(self.data) / 'utf-16.file' with target.open(encoding='utf-8', errors='strict') as fp: self.assertRaises(UnicodeError, fp.read) @@ -54,11 +56,13 @@ def test_open_text_with_errors(self): def test_open_binary_FileNotFoundError(self): target = resources.files(self.data) / 'does-not-exist' - self.assertRaises(FileNotFoundError, target.open, 'rb') + with self.assertRaises(FileNotFoundError): + target.open('rb') def test_open_text_FileNotFoundError(self): target = resources.files(self.data) / 'does-not-exist' - self.assertRaises(FileNotFoundError, target.open) + with self.assertRaises(FileNotFoundError): + target.open(encoding='utf-8') class OpenDiskTests(OpenTests, unittest.TestCase): @@ -72,12 +76,6 @@ def setUp(self): self.data = namespacedata01 - # TODO: RUSTPYTHON - import sys - if sys.platform == 'win32': - @unittest.expectedFailure - def test_open_text_default_encoding(self): - super().test_open_text_default_encoding() class OpenZipTests(OpenTests, util.ZipSetup, unittest.TestCase): pass diff --git a/Lib/test/test_importlib/test_path.py b/Lib/test/test_importlib/resources/test_path.py similarity index 84% rename from Lib/test/test_importlib/test_path.py rename to Lib/test/test_importlib/resources/test_path.py index 6fc41f301d..34a6bdd2d5 100644 --- a/Lib/test/test_importlib/test_path.py +++ b/Lib/test/test_importlib/resources/test_path.py @@ -3,7 +3,7 @@ from importlib import resources from . import data01 -from .resources import util +from . import util class CommonTests(util.CommonTests, unittest.TestCase): @@ -14,9 +14,12 @@ def execute(self, package, path): class PathTests: def test_reading(self): - # Path should be readable. - # Test also implicitly verifies the returned object is a pathlib.Path - # instance. + """ + Path should be readable. + + Test also implicitly verifies the returned object is a pathlib.Path + instance. + """ target = resources.files(self.data) / 'utf-8.file' with resources.as_file(target) as path: self.assertTrue(path.name.endswith("utf-8.file"), repr(path)) @@ -51,8 +54,10 @@ def setUp(self): class PathZipTests(PathTests, util.ZipSetup, unittest.TestCase): def test_remove_in_context_manager(self): - # It is not an error if the file that was temporarily stashed on the - # file system is removed inside the `with` stanza. + """ + It is not an error if the file that was temporarily stashed on the + file system is removed inside the `with` stanza. + """ target = resources.files(self.data) / 'utf-8.file' with resources.as_file(target) as path: path.unlink() diff --git a/Lib/test/test_importlib/test_read.py b/Lib/test/test_importlib/resources/test_read.py similarity index 86% rename from Lib/test/test_importlib/test_read.py rename to Lib/test/test_importlib/resources/test_read.py index ebd7226777..088982681e 100644 --- a/Lib/test/test_importlib/test_read.py +++ b/Lib/test/test_importlib/resources/test_read.py @@ -2,7 +2,7 @@ from importlib import import_module, resources from . import data01 -from .resources import util +from . import util class CommonBinaryTests(util.CommonTests, unittest.TestCase): @@ -12,7 +12,7 @@ def execute(self, package, path): class CommonTextTests(util.CommonTests, unittest.TestCase): def execute(self, package, path): - resources.files(package).joinpath(path).read_text() + resources.files(package).joinpath(path).read_text(encoding='utf-8') class ReadTests: @@ -21,7 +21,11 @@ def test_read_bytes(self): self.assertEqual(result, b'\0\1\2\3') def test_read_text_default_encoding(self): - result = resources.files(self.data).joinpath('utf-8.file').read_text() + result = ( + resources.files(self.data) + .joinpath('utf-8.file') + .read_text(encoding='utf-8') + ) self.assertEqual(result, 'Hello, UTF-8 world!\n') def test_read_text_given_encoding(self): @@ -33,7 +37,9 @@ def test_read_text_given_encoding(self): self.assertEqual(result, 'Hello, UTF-16 world!\n') def test_read_text_with_errors(self): - # Raises UnicodeError without the 'errors' argument. + """ + Raises UnicodeError without the 'errors' argument. + """ target = resources.files(self.data) / 'utf-16.file' self.assertRaises(UnicodeError, target.read_text, encoding='utf-8') result = target.read_text(encoding='utf-8', errors='ignore') diff --git a/Lib/test/test_importlib/test_reader.py b/Lib/test/test_importlib/resources/test_reader.py similarity index 85% rename from Lib/test/test_importlib/test_reader.py rename to Lib/test/test_importlib/resources/test_reader.py index 9d20c976b8..8670f72a33 100644 --- a/Lib/test/test_importlib/test_reader.py +++ b/Lib/test/test_importlib/resources/test_reader.py @@ -75,6 +75,22 @@ def test_join_path(self): str(path.joinpath('imaginary'))[len(prefix) + 1 :], os.path.join('namespacedata01', 'imaginary'), ) + self.assertEqual(path.joinpath(), path) + + def test_join_path_compound(self): + path = MultiplexedPath(self.folder) + assert not path.joinpath('imaginary/foo.py').exists() + + def test_join_path_common_subdir(self): + prefix = os.path.abspath(os.path.join(__file__, '..')) + data01 = os.path.join(prefix, 'data01') + data02 = os.path.join(prefix, 'data02') + path = MultiplexedPath(data01, data02) + self.assertIsInstance(path.joinpath('subdirectory'), MultiplexedPath) + self.assertEqual( + str(path.joinpath('subdirectory', 'subsubdir'))[len(prefix) + 1 :], + os.path.join('data02', 'subdirectory', 'subsubdir'), + ) def test_repr(self): self.assertEqual( diff --git a/Lib/test/test_importlib/test_resource.py b/Lib/test/test_importlib/resources/test_resource.py similarity index 74% rename from Lib/test/test_importlib/test_resource.py rename to Lib/test/test_importlib/resources/test_resource.py index 834b8bd8a2..6f75cf57f0 100644 --- a/Lib/test/test_importlib/test_resource.py +++ b/Lib/test/test_importlib/resources/test_resource.py @@ -1,3 +1,4 @@ +import contextlib import sys import unittest import uuid @@ -5,9 +6,9 @@ from . import data01 from . import zipdata01, zipdata02 -from .resources import util +from . import util from importlib import resources, import_module -from test.support import import_helper +from test.support import import_helper, os_helper from test.support.os_helper import unlink @@ -69,10 +70,12 @@ def test_resource_missing(self): class ResourceCornerCaseTests(unittest.TestCase): def test_package_has_no_reader_fallback(self): - # Test odd ball packages which: + """ + Test odd ball packages which: # 1. Do not have a ResourceReader as a loader # 2. Are not on the file system # 3. Are not in a zip file + """ module = util.create_package( file=data01, path=data01.__file__, contents=['A', 'B', 'C'] ) @@ -111,6 +114,14 @@ def test_submodule_contents_by_name(self): {'__init__.py', 'binary.file'}, ) + def test_as_file_directory(self): + with resources.as_file(resources.files('ziptestdata')) as data: + assert data.name == 'ziptestdata' + assert data.is_dir() + assert data.joinpath('subdirectory').is_dir() + assert len(list(data.iterdir())) + assert not data.parent.exists() + class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase): ZIP_MODULE = zipdata02 # type: ignore @@ -130,82 +141,71 @@ def test_unrelated_contents(self): ) +@contextlib.contextmanager +def zip_on_path(dir): + data_path = pathlib.Path(zipdata01.__file__) + source_zip_path = data_path.parent.joinpath('ziptestdata.zip') + zip_path = pathlib.Path(dir) / f'{uuid.uuid4()}.zip' + zip_path.write_bytes(source_zip_path.read_bytes()) + sys.path.append(str(zip_path)) + import_module('ziptestdata') + + try: + yield + finally: + with contextlib.suppress(ValueError): + sys.path.remove(str(zip_path)) + + with contextlib.suppress(KeyError): + del sys.path_importer_cache[str(zip_path)] + del sys.modules['ziptestdata'] + + with contextlib.suppress(OSError): + unlink(zip_path) + + class DeletingZipsTest(unittest.TestCase): """Having accessed resources in a zip file should not keep an open reference to the zip. """ - ZIP_MODULE = zipdata01 - def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + modules = import_helper.modules_setup() self.addCleanup(import_helper.modules_cleanup, *modules) - data_path = pathlib.Path(self.ZIP_MODULE.__file__) - data_dir = data_path.parent - self.source_zip_path = data_dir / 'ziptestdata.zip' - self.zip_path = pathlib.Path(f'{uuid.uuid4()}.zip').absolute() - self.zip_path.write_bytes(self.source_zip_path.read_bytes()) - sys.path.append(str(self.zip_path)) - self.data = import_module('ziptestdata') - - def tearDown(self): - try: - sys.path.remove(str(self.zip_path)) - except ValueError: - pass - - try: - del sys.path_importer_cache[str(self.zip_path)] - del sys.modules[self.data.__name__] - except KeyError: - pass - - try: - unlink(self.zip_path) - except OSError: - # If the test fails, this will probably fail too - pass + temp_dir = self.fixtures.enter_context(os_helper.temp_dir()) + self.fixtures.enter_context(zip_on_path(temp_dir)) def test_iterdir_does_not_keep_open(self): - c = [item.name for item in resources.files('ziptestdata').iterdir()] - self.zip_path.unlink() - del c + [item.name for item in resources.files('ziptestdata').iterdir()] def test_is_file_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('binary.file').is_file() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('binary.file').is_file() def test_is_file_failure_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('not-present').is_file() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('not-present').is_file() @unittest.skip("Desired but not supported.") def test_as_file_does_not_keep_open(self): # pragma: no cover - c = resources.as_file(resources.files('ziptestdata') / 'binary.file') - self.zip_path.unlink() - del c + resources.as_file(resources.files('ziptestdata') / 'binary.file') def test_entered_path_does_not_keep_open(self): - # This is what certifi does on import to make its bundle - # available for the process duration. - c = resources.as_file( - resources.files('ziptestdata') / 'binary.file' - ).__enter__() - self.zip_path.unlink() - del c + """ + Mimic what certifi does on import to make its bundle + available for the process duration. + """ + resources.as_file(resources.files('ziptestdata') / 'binary.file').__enter__() def test_read_binary_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('binary.file').read_bytes() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('binary.file').read_bytes() def test_read_text_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('utf-8.file').read_text() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('utf-8.file').read_text( + encoding='utf-8' + ) class ResourceFromNamespaceTest01(unittest.TestCase): diff --git a/Lib/test/test_importlib/update-zips.py b/Lib/test/test_importlib/resources/update-zips.py similarity index 100% rename from Lib/test/test_importlib/update-zips.py rename to Lib/test/test_importlib/resources/update-zips.py diff --git a/Lib/test/test_importlib/resources/util.py b/Lib/test/test_importlib/resources/util.py index 11c8aa8080..dbe6ee8147 100644 --- a/Lib/test/test_importlib/resources/util.py +++ b/Lib/test/test_importlib/resources/util.py @@ -3,11 +3,11 @@ import io import sys import types -from pathlib import Path, PurePath +import pathlib -from .. import data01 -from .. import zipdata01 -from importlib.abc import ResourceReader +from . import data01 +from . import zipdata01 +from importlib.resources.abc import ResourceReader from test.support import import_helper @@ -80,43 +80,44 @@ def execute(self, package, path): """ def test_package_name(self): - # Passing in the package name should succeed. + """ + Passing in the package name should succeed. + """ self.execute(data01.__name__, 'utf-8.file') def test_package_object(self): - # Passing in the package itself should succeed. + """ + Passing in the package itself should succeed. + """ self.execute(data01, 'utf-8.file') def test_string_path(self): - # Passing in a string for the path should succeed. + """ + Passing in a string for the path should succeed. + """ path = 'utf-8.file' self.execute(data01, path) def test_pathlib_path(self): - # Passing in a pathlib.PurePath object for the path should succeed. - path = PurePath('utf-8.file') + """ + Passing in a pathlib.PurePath object for the path should succeed. + """ + path = pathlib.PurePath('utf-8.file') self.execute(data01, path) def test_importing_module_as_side_effect(self): - # The anchor package can already be imported. + """ + The anchor package can already be imported. + """ del sys.modules[data01.__name__] self.execute(data01.__name__, 'utf-8.file') - def test_non_package_by_name(self): - # The anchor package cannot be a module. - with self.assertRaises(TypeError): - self.execute(__name__, 'utf-8.file') - - def test_non_package_by_package(self): - # The anchor package cannot be a module. - with self.assertRaises(TypeError): - module = sys.modules['test.test_importlib.resources.util'] - self.execute(module, 'utf-8.file') - def test_missing_path(self): - # Attempting to open or read or request the path for a - # non-existent path should succeed if open_resource - # can return a viable data stream. + """ + Attempting to open or read or request the path for a + non-existent path should succeed if open_resource + can return a viable data stream. + """ bytes_data = io.BytesIO(b'Hello, world!') package = create_package(file=bytes_data, path=FileNotFoundError()) self.execute(package, 'utf-8.file') @@ -144,7 +145,7 @@ class ZipSetupBase: @classmethod def setUpClass(cls): - data_path = Path(cls.ZIP_MODULE.__file__) + data_path = pathlib.Path(cls.ZIP_MODULE.__file__) data_dir = data_path.parent cls._zip_path = str(data_dir / 'ziptestdata.zip') sys.path.append(cls._zip_path) diff --git a/Lib/test/test_importlib/zipdata01/__init__.py b/Lib/test/test_importlib/resources/zipdata01/__init__.py similarity index 100% rename from Lib/test/test_importlib/zipdata01/__init__.py rename to Lib/test/test_importlib/resources/zipdata01/__init__.py diff --git a/Lib/test/test_importlib/resources/zipdata01/ziptestdata.zip b/Lib/test/test_importlib/resources/zipdata01/ziptestdata.zip new file mode 100644 index 0000000000000000000000000000000000000000..9a3bb0739f87e97c1084b94d7d153680f6727738 GIT binary patch literal 876 zcmWIWW@Zs#00HOCX@Q%&m27l?Y!DU);;PJolGNgol*E!m{nC;&T|+ayw9K5;|NlG~ zQWMD z9;rDw`8o=rA#S=B3g!7lIVp-}COK17UPc zNtt;*xhM-3R!jMEPhCreO-3*u>5Df}T7+BJ{639e$2uhfsIs`pJ5Qf}C xGXyDE@VNvOv@o!wQJfLgCAgysx3f@9jKpUmiW^zkK<;1z!tFpk^MROw0RS~O%0&PG literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/zipdata02/__init__.py b/Lib/test/test_importlib/resources/zipdata02/__init__.py similarity index 100% rename from Lib/test/test_importlib/zipdata02/__init__.py rename to Lib/test/test_importlib/resources/zipdata02/__init__.py diff --git a/Lib/test/test_importlib/resources/zipdata02/ziptestdata.zip b/Lib/test/test_importlib/resources/zipdata02/ziptestdata.zip new file mode 100644 index 0000000000000000000000000000000000000000..d63ff512d2807ef2fd259455283b81b02e0e45fb GIT binary patch literal 698 zcmWIWW@Zs#00HOCX@Ot{ln@8fRhb1Psl_EJi6x2p@$s2?nI-Y@dIgmMI5kP5Y0A$_ z#jWw|&p#`9ff_(q7K_HB)Z+ZoqU2OVy^@L&ph*fa0WRVlP*R?c+X1opI-R&20MZDv z&j{oIpa8N17@0(vaR(gGH(;=&5k%n(M%;#g0ulz6G@1gL$cA79E2=^00gEsw4~s!C zUxI@ZWaIMqz|BszK;s4KsL2<9jRy!Q2E6`2cTLHjr{wAk1ZCU@!+_ G1_l6Bc%f?m literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/source/test_case_sensitivity.py b/Lib/test/test_importlib/source/test_case_sensitivity.py index 9d472707ab..6a06313319 100644 --- a/Lib/test/test_importlib/source/test_case_sensitivity.py +++ b/Lib/test/test_importlib/source/test_case_sensitivity.py @@ -63,19 +63,6 @@ def test_insensitive(self): self.assertIn(self.name, insensitive.get_filename(self.name)) -class CaseSensitivityTestPEP302(CaseSensitivityTest): - def find(self, finder): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return finder.find_module(self.name) - - -(Frozen_CaseSensitivityTestPEP302, - Source_CaseSensitivityTestPEP302 - ) = util.test_both(CaseSensitivityTestPEP302, importlib=importlib, - machinery=machinery) - - class CaseSensitivityTestPEP451(CaseSensitivityTest): def find(self, finder): found = finder.find_spec(self.name) diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py index ebf6ec68d7..9c85bd234f 100644 --- a/Lib/test/test_importlib/source/test_file_loader.py +++ b/Lib/test/test_importlib/source/test_file_loader.py @@ -51,7 +51,6 @@ class Tester(self.abc.FileLoader): def get_code(self, _): pass def get_source(self, _): pass def is_package(self, _): pass - def module_repr(self, _): pass path = 'some_path' name = 'some_name' diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py index 3c12ab0123..17d09d4cee 100644 --- a/Lib/test/test_importlib/source/test_finder.py +++ b/Lib/test/test_importlib/source/test_finder.py @@ -120,7 +120,7 @@ def test_package_over_module(self): def test_failure(self): with util.create_modules('blah') as mapping: nothing = self.import_(mapping['.root'], 'sdfsadsadf') - self.assertIsNone(nothing) + self.assertEqual(nothing, self.NOT_FOUND) def test_empty_string_for_dir(self): # The empty string from sys.path means to search in the cwd. @@ -150,7 +150,7 @@ def test_dir_removal_handling(self): found = self._find(finder, 'mod', loader_only=True) self.assertIsNotNone(found) found = self._find(finder, 'mod', loader_only=True) - self.assertIsNone(found) + self.assertEqual(found, self.NOT_FOUND) @unittest.skipUnless(sys.platform != 'win32', 'os.chmod() does not support the needed arguments under Windows') @@ -197,10 +197,12 @@ class FinderTestsPEP420(FinderTests): NOT_FOUND = (None, []) def _find(self, finder, name, loader_only=False): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader_portions = finder.find_loader(name) - return loader_portions[0] if loader_only else loader_portions + spec = finder.find_spec(name) + if spec is None: + return self.NOT_FOUND + if loader_only: + return spec.loader + return spec.loader, spec.submodule_search_locations (Frozen_FinderTestsPEP420, @@ -208,20 +210,5 @@ def _find(self, finder, name, loader_only=False): ) = util.test_both(FinderTestsPEP420, machinery=machinery) -class FinderTestsPEP302(FinderTests): - - NOT_FOUND = None - - def _find(self, finder, name, loader_only=False): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return finder.find_module(name) - - -(Frozen_FinderTestsPEP302, - Source_FinderTestsPEP302 - ) = util.test_both(FinderTestsPEP302, machinery=machinery) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/source/test_path_hook.py b/Lib/test/test_importlib/source/test_path_hook.py index ead62f5e94..f274330e0b 100644 --- a/Lib/test/test_importlib/source/test_path_hook.py +++ b/Lib/test/test_importlib/source/test_path_hook.py @@ -18,19 +18,10 @@ def test_success(self): self.assertTrue(hasattr(self.path_hook()(mapping['.root']), 'find_spec')) - def test_success_legacy(self): - with util.create_modules('dummy') as mapping: - self.assertTrue(hasattr(self.path_hook()(mapping['.root']), - 'find_module')) - def test_empty_string(self): # The empty string represents the cwd. self.assertTrue(hasattr(self.path_hook()(''), 'find_spec')) - def test_empty_string_legacy(self): - # The empty string represents the cwd. - self.assertTrue(hasattr(self.path_hook()(''), 'find_module')) - (Frozen_PathHookTest, Source_PathHooktest diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py index d77b8a0a4d..a231ae1d5f 100644 --- a/Lib/test/test_importlib/test_abc.py +++ b/Lib/test/test_importlib/test_abc.py @@ -2,7 +2,6 @@ import marshal import os import sys -from test import support from test.support import import_helper import types import unittest @@ -148,20 +147,13 @@ def ins(self): class MetaPathFinder: - def find_module(self, fullname, path): - return super().find_module(fullname, path) + pass class MetaPathFinderDefaultsTests(ABCTestHarness): SPLIT = make_abc_subclasses(MetaPathFinder) - def test_find_module(self): - # Default should return None. - with self.assertWarns(DeprecationWarning): - found = self.ins.find_module('something', None) - self.assertIsNone(found) - def test_invalidate_caches(self): # Calling the method is a no-op. self.ins.invalidate_caches() @@ -174,22 +166,13 @@ def test_invalidate_caches(self): class PathEntryFinder: - def find_loader(self, fullname): - return super().find_loader(fullname) + pass class PathEntryFinderDefaultsTests(ABCTestHarness): SPLIT = make_abc_subclasses(PathEntryFinder) - def test_find_loader(self): - with self.assertWarns(DeprecationWarning): - found = self.ins.find_loader('something') - self.assertEqual(found, (None, [])) - - def find_module(self): - self.assertEqual(None, self.ins.find_module('something')) - def test_invalidate_caches(self): # Should be a no-op. self.ins.invalidate_caches() @@ -202,8 +185,7 @@ def test_invalidate_caches(self): class Loader: - def load_module(self, fullname): - return super().load_module(fullname) + pass class LoaderDefaultsTests(ABCTestHarness): @@ -222,8 +204,6 @@ def test_module_repr(self): mod = types.ModuleType('blah') with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) - with self.assertRaises(NotImplementedError): - self.ins.module_repr(mod) original_repr = repr(mod) mod.__loader__ = self.ins # Should still return a proper repr. @@ -323,32 +303,6 @@ def contents(self, *args, **kwargs): return super().contents(*args, **kwargs) -class ResourceReaderDefaultsTests(ABCTestHarness): - - SPLIT = make_abc_subclasses(ResourceReader) - - def test_open_resource(self): - with self.assertRaises(FileNotFoundError): - self.ins.open_resource('dummy_file') - - def test_resource_path(self): - with self.assertRaises(FileNotFoundError): - self.ins.resource_path('dummy_file') - - def test_is_resource(self): - with self.assertRaises(FileNotFoundError): - self.ins.is_resource('dummy_file') - - def test_contents(self): - with self.assertRaises(FileNotFoundError): - self.ins.contents() - - -(Frozen_RRDefaultTests, - Source_RRDefaultsTests - ) = test_util.test_both(ResourceReaderDefaultsTests) - - ##### MetaPathFinder concrete methods ########################################## class MetaPathFinderFindModuleTests: @@ -362,14 +316,6 @@ def find_spec(self, fullname, path, target=None): return MetaPathSpecFinder() - def test_find_module(self): - finder = self.finder(None) - path = ['a', 'b', 'c'] - name = 'blah' - with self.assertWarns(DeprecationWarning): - found = finder.find_module(name, path) - self.assertIsNone(found) - def test_find_spec_with_explicit_target(self): loader = object() spec = self.util.spec_from_loader('blah', loader) @@ -399,53 +345,6 @@ def test_spec(self): ) = test_util.test_both(MetaPathFinderFindModuleTests, abc=abc, util=util) -##### PathEntryFinder concrete methods ######################################### -class PathEntryFinderFindLoaderTests: - - @classmethod - def finder(cls, spec): - class PathEntrySpecFinder(cls.abc.PathEntryFinder): - - def find_spec(self, fullname, target=None): - self.called_for = fullname - return spec - - return PathEntrySpecFinder() - - def test_no_spec(self): - finder = self.finder(None) - name = 'blah' - with self.assertWarns(DeprecationWarning): - found = finder.find_loader(name) - self.assertIsNone(found[0]) - self.assertEqual([], found[1]) - self.assertEqual(name, finder.called_for) - - def test_spec_with_loader(self): - loader = object() - spec = self.util.spec_from_loader('blah', loader) - finder = self.finder(spec) - with self.assertWarns(DeprecationWarning): - found = finder.find_loader('blah') - self.assertIs(found[0], spec.loader) - - def test_spec_with_portions(self): - spec = self.machinery.ModuleSpec('blah', None) - paths = ['a', 'b', 'c'] - spec.submodule_search_locations = paths - finder = self.finder(spec) - with self.assertWarns(DeprecationWarning): - found = finder.find_loader('blah') - self.assertIsNone(found[0]) - self.assertEqual(paths, found[1]) - - -(Frozen_PEFFindLoaderTests, - Source_PEFFindLoaderTests - ) = test_util.test_both(PathEntryFinderFindLoaderTests, abc=abc, util=util, - machinery=machinery) - - ##### Loader concrete methods ################################################## class LoaderLoadModuleTests: @@ -716,9 +615,6 @@ def get_data(self, path): def get_filename(self, fullname): return self.path - def module_repr(self, module): - return '' - SPLIT_SOL = make_abc_subclasses(SourceOnlyLoader, 'SourceLoader') @@ -803,13 +699,7 @@ def verify_code(self, code_object): class SourceOnlyLoaderTests(SourceLoaderTestHarness): - - """Test importlib.abc.SourceLoader for source-only loading. - - Reload testing is subsumed by the tests for - importlib.util.module_for_loader. - - """ + """Test importlib.abc.SourceLoader for source-only loading.""" # TODO: RUSTPYTHON @unittest.expectedFailure diff --git a/Lib/test/test_importlib/test_api.py b/Lib/test/test_importlib/test_api.py index 1beb7835d4..ecf2c47c46 100644 --- a/Lib/test/test_importlib/test_api.py +++ b/Lib/test/test_importlib/test_api.py @@ -6,7 +6,6 @@ import os.path import sys -from test import support from test.support import import_helper from test.support import os_helper import types @@ -96,7 +95,8 @@ def load_b(): (Frozen_ImportModuleTests, Source_ImportModuleTests - ) = test_util.test_both(ImportModuleTests, init=init) + ) = test_util.test_both( + ImportModuleTests, init=init, util=util, machinery=machinery) class FindLoaderTests: @@ -104,29 +104,26 @@ class FindLoaderTests: FakeMetaFinder = None def test_sys_modules(self): - # If a module with __loader__ is in sys.modules, then return it. + # If a module with __spec__.loader is in sys.modules, then return it. name = 'some_mod' with test_util.uncache(name): module = types.ModuleType(name) loader = 'a loader!' - module.__loader__ = loader + module.__spec__ = self.machinery.ModuleSpec(name, loader) sys.modules[name] = module - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - found = self.init.find_loader(name) - self.assertEqual(loader, found) + spec = self.util.find_spec(name) + self.assertIsNotNone(spec) + self.assertEqual(spec.loader, loader) def test_sys_modules_loader_is_None(self): - # If sys.modules[name].__loader__ is None, raise ValueError. + # If sys.modules[name].__spec__.loader is None, raise ValueError. name = 'some_mod' with test_util.uncache(name): module = types.ModuleType(name) module.__loader__ = None sys.modules[name] = module with self.assertRaises(ValueError): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.init.find_loader(name) + self.util.find_spec(name) def test_sys_modules_loader_is_not_set(self): # Should raise ValueError @@ -135,24 +132,20 @@ def test_sys_modules_loader_is_not_set(self): with test_util.uncache(name): module = types.ModuleType(name) try: - del module.__loader__ + del module.__spec__.loader except AttributeError: pass sys.modules[name] = module with self.assertRaises(ValueError): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.init.find_loader(name) + self.util.find_spec(name) def test_success(self): # Return the loader found on sys.meta_path. name = 'some_mod' with test_util.uncache(name): with test_util.import_state(meta_path=[self.FakeMetaFinder]): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - warnings.simplefilter('ignore', ImportWarning) - self.assertEqual((name, None), self.init.find_loader(name)) + spec = self.util.find_spec(name) + self.assertEqual((name, (name, None)), (spec.name, spec.loader)) def test_success_path(self): # Searching on a path should work. @@ -160,17 +153,12 @@ def test_success_path(self): path = 'path to some place' with test_util.uncache(name): with test_util.import_state(meta_path=[self.FakeMetaFinder]): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - warnings.simplefilter('ignore', ImportWarning) - self.assertEqual((name, path), - self.init.find_loader(name, path)) + spec = self.util.find_spec(name, path) + self.assertEqual(name, spec.name) def test_nothing(self): # None is returned upon failure to find a loader. - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertIsNone(self.init.find_loader('nevergoingtofindthismodule')) + self.assertIsNone(self.util.find_spec('nevergoingtofindthismodule')) class FindLoaderPEP451Tests(FindLoaderTests): @@ -183,20 +171,8 @@ def find_spec(name, path=None, target=None): (Frozen_FindLoaderPEP451Tests, Source_FindLoaderPEP451Tests - ) = test_util.test_both(FindLoaderPEP451Tests, init=init) - - -class FindLoaderPEP302Tests(FindLoaderTests): - - class FakeMetaFinder: - @staticmethod - def find_module(name, path=None): - return name, path - - -(Frozen_FindLoaderPEP302Tests, - Source_FindLoaderPEP302Tests - ) = test_util.test_both(FindLoaderPEP302Tests, init=init) + ) = test_util.test_both( + FindLoaderPEP451Tests, init=init, util=util, machinery=machinery) class ReloadTests: @@ -301,7 +277,8 @@ def test_reload_namespace_changed(self): name = 'spam' with os_helper.temp_cwd(None) as cwd: with test_util.uncache('spam'): - with import_helper.DirsOnSysPath(cwd): + with test_util.import_state(path=[cwd]): + self.init._bootstrap_external._install(self.init._bootstrap) # Start as a namespace package. self.init.invalidate_caches() bad_path = os.path.join(cwd, name, '__init.py') @@ -380,7 +357,8 @@ def test_module_missing_spec(self): (Frozen_ReloadTests, Source_ReloadTests - ) = test_util.test_both(ReloadTests, init=init, util=util) + ) = test_util.test_both( + ReloadTests, init=init, util=util, machinery=machinery) class InvalidateCacheTests: @@ -390,8 +368,6 @@ def test_method_called(self): class InvalidatingNullFinder: def __init__(self, *ignored): self.called = False - def find_module(self, *args): - return None def invalidate_caches(self): self.called = True @@ -416,7 +392,8 @@ def test_method_lacking(self): (Frozen_InvalidateCacheTests, Source_InvalidateCacheTests - ) = test_util.test_both(InvalidateCacheTests, init=init) + ) = test_util.test_both( + InvalidateCacheTests, init=init, util=util, machinery=machinery) class FrozenImportlibTests(unittest.TestCase): diff --git a/Lib/test/test_importlib/test_files.py b/Lib/test/test_importlib/test_files.py deleted file mode 100644 index b9170d83be..0000000000 --- a/Lib/test/test_importlib/test_files.py +++ /dev/null @@ -1,46 +0,0 @@ -import typing -import unittest - -from importlib import resources -from importlib.abc import Traversable -from . import data01 -from .resources import util - - -class FilesTests: - def test_read_bytes(self): - files = resources.files(self.data) - actual = files.joinpath('utf-8.file').read_bytes() - assert actual == b'Hello, UTF-8 world!\n' - - def test_read_text(self): - files = resources.files(self.data) - actual = files.joinpath('utf-8.file').read_text(encoding='utf-8') - assert actual == 'Hello, UTF-8 world!\n' - - @unittest.skipUnless( - hasattr(typing, 'runtime_checkable'), - "Only suitable when typing supports runtime_checkable", - ) - def test_traversable(self): - assert isinstance(resources.files(self.data), Traversable) - - -class OpenDiskTests(FilesTests, unittest.TestCase): - def setUp(self): - self.data = data01 - - -class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase): - pass - - -class OpenNamespaceTests(FilesTests, unittest.TestCase): - def setUp(self): - from . import namespacedata01 - - self.data = namespacedata01 - - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py index 32ed67c308..17cce741cc 100644 --- a/Lib/test/test_importlib/test_locks.py +++ b/Lib/test/test_importlib/test_locks.py @@ -33,6 +33,11 @@ class ModuleLockAsRLockTests: test_repr = None test_locked_repr = None + def tearDown(self): + for splitinit in init.values(): + splitinit._bootstrap._blocking_on.clear() + + LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock for kind, splitinit in init.items()} diff --git a/Lib/test/test_importlib/test_main.py b/Lib/test/test_importlib/test_main.py index d9d067c4b2..81f683799c 100644 --- a/Lib/test/test_importlib/test_main.py +++ b/Lib/test/test_importlib/test_main.py @@ -1,9 +1,10 @@ import re -import json import pickle import unittest import warnings import importlib.metadata +import contextlib +import itertools try: import pyfakefs.fake_filesystem_unittest as ffs @@ -11,6 +12,7 @@ from .stubs import fake_filesystem_unittest as ffs from . import fixtures +from ._context import suppress from importlib.metadata import ( Distribution, EntryPoint, @@ -24,6 +26,13 @@ ) +@contextlib.contextmanager +def suppress_known_deprecation(): + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter('default', category=DeprecationWarning) + yield ctx + + class BasicTests(fixtures.DistInfoPkg, unittest.TestCase): version_pattern = r'\d+\.\d+(\.\d)?' @@ -39,7 +48,7 @@ def test_for_name_does_not_exist(self): def test_package_not_found_mentions_metadata(self): """ When a package is not found, that could indicate that the - packgae is not installed or that it is installed without + package is not installed or that it is installed without metadata. Ensure the exception mentions metadata to help guide users toward the cause. See #124. """ @@ -48,15 +57,19 @@ def test_package_not_found_mentions_metadata(self): assert "metadata" in str(ctx.exception) - def test_new_style_classes(self): - self.assertIsInstance(Distribution, type) + # expected to fail until ABC is enforced + @suppress(AssertionError) + @suppress_known_deprecation() + def test_abc_enforced(self): + with self.assertRaises(TypeError): + type('DistributionSubclass', (Distribution,), {})() @fixtures.parameterize( dict(name=None), dict(name=''), ) def test_invalid_inputs_to_from_name(self, name): - with self.assertRaises(Exception): + with self.assertRaises(ValueError): Distribution.from_name(name) @@ -174,11 +187,21 @@ def test_metadata_loads_egg_info(self): assert meta['Description'] == 'pôrˈtend' -class DiscoveryTests(fixtures.EggInfoPkg, fixtures.DistInfoPkg, unittest.TestCase): +class DiscoveryTests( + fixtures.EggInfoPkg, + fixtures.EggInfoPkgPipInstalledNoToplevel, + fixtures.EggInfoPkgPipInstalledNoModules, + fixtures.EggInfoPkgSourcesFallback, + fixtures.DistInfoPkg, + unittest.TestCase, +): def test_package_discovery(self): dists = list(distributions()) assert all(isinstance(dist, Distribution) for dist in dists) assert any(dist.metadata['Name'] == 'egginfo-pkg' for dist in dists) + assert any(dist.metadata['Name'] == 'egg_with_module-pkg' for dist in dists) + assert any(dist.metadata['Name'] == 'egg_with_no_modules-pkg' for dist in dists) + assert any(dist.metadata['Name'] == 'sources_fallback-pkg' for dist in dists) assert any(dist.metadata['Name'] == 'distinfo-pkg' for dist in dists) def test_invalid_usage(self): @@ -260,14 +283,6 @@ def test_hashable(self): """EntryPoints should be hashable""" hash(self.ep) - def test_json_dump(self): - """ - json should not expect to be able to dump an EntryPoint - """ - with self.assertRaises(Exception): - with warnings.catch_warnings(record=True): - json.dumps(self.ep) - def test_module(self): assert self.ep.module == 'value' @@ -334,3 +349,79 @@ def test_packages_distributions_neither_toplevel_nor_files(self): prefix=self.site_dir, ) packages_distributions() + + def test_packages_distributions_all_module_types(self): + """ + Test top-level modules detected on a package without 'top-level.txt'. + """ + suffixes = importlib.machinery.all_suffixes() + metadata = dict( + METADATA=""" + Name: all_distributions + Version: 1.0.0 + """, + ) + files = { + 'all_distributions-1.0.0.dist-info': metadata, + } + for i, suffix in enumerate(suffixes): + files.update( + { + f'importable-name {i}{suffix}': '', + f'in_namespace_{i}': { + f'mod{suffix}': '', + }, + f'in_package_{i}': { + '__init__.py': '', + f'mod{suffix}': '', + }, + } + ) + metadata.update(RECORD=fixtures.build_record(files)) + fixtures.build_files(files, prefix=self.site_dir) + + distributions = packages_distributions() + + for i in range(len(suffixes)): + assert distributions[f'importable-name {i}'] == ['all_distributions'] + assert distributions[f'in_namespace_{i}'] == ['all_distributions'] + assert distributions[f'in_package_{i}'] == ['all_distributions'] + + assert not any(name.endswith('.dist-info') for name in distributions) + + +class PackagesDistributionsEggTest( + fixtures.EggInfoPkg, + fixtures.EggInfoPkgPipInstalledNoToplevel, + fixtures.EggInfoPkgPipInstalledNoModules, + fixtures.EggInfoPkgSourcesFallback, + unittest.TestCase, +): + def test_packages_distributions_on_eggs(self): + """ + Test old-style egg packages with a variation of 'top_level.txt', + 'SOURCES.txt', and 'installed-files.txt', available. + """ + distributions = packages_distributions() + + def import_names_from_package(package_name): + return { + import_name + for import_name, package_names in distributions.items() + if package_name in package_names + } + + # egginfo-pkg declares one import ('mod') via top_level.txt + assert import_names_from_package('egginfo-pkg') == {'mod'} + + # egg_with_module-pkg has one import ('egg_with_module') inferred from + # installed-files.txt (top_level.txt is missing) + assert import_names_from_package('egg_with_module-pkg') == {'egg_with_module'} + + # egg_with_no_modules-pkg should not be associated with any import names + # (top_level.txt is empty, and installed-files.txt has no .py files) + assert import_names_from_package('egg_with_no_modules-pkg') == set() + + # sources_fallback-pkg has one import ('sources_fallback') inferred from + # SOURCES.txt (top_level.txt and installed-files.txt is missing) + assert import_names_from_package('sources_fallback-pkg') == {'sources_fallback'} diff --git a/Lib/test/test_importlib/test_metadata_api.py b/Lib/test/test_importlib/test_metadata_api.py index abf568fcca..55c9f8007e 100644 --- a/Lib/test/test_importlib/test_metadata_api.py +++ b/Lib/test/test_importlib/test_metadata_api.py @@ -27,12 +27,14 @@ def suppress_known_deprecation(): class APITests( fixtures.EggInfoPkg, + fixtures.EggInfoPkgPipInstalledNoToplevel, + fixtures.EggInfoPkgPipInstalledNoModules, + fixtures.EggInfoPkgSourcesFallback, fixtures.DistInfoPkg, fixtures.DistInfoPkgWithDot, fixtures.EggInfoFile, unittest.TestCase, ): - version_pattern = r'\d+\.\d+(\.\d)?' def test_retrieves_version_of_self(self): @@ -63,15 +65,28 @@ def test_prefix_not_matched(self): distribution(prefix) def test_for_top_level(self): - self.assertEqual( - distribution('egginfo-pkg').read_text('top_level.txt').strip(), 'mod' - ) + tests = [ + ('egginfo-pkg', 'mod'), + ('egg_with_no_modules-pkg', ''), + ] + for pkg_name, expect_content in tests: + with self.subTest(pkg_name): + self.assertEqual( + distribution(pkg_name).read_text('top_level.txt').strip(), + expect_content, + ) def test_read_text(self): - top_level = [ - path for path in files('egginfo-pkg') if path.name == 'top_level.txt' - ][0] - self.assertEqual(top_level.read_text(), 'mod\n') + tests = [ + ('egginfo-pkg', 'mod\n'), + ('egg_with_no_modules-pkg', '\n'), + ] + for pkg_name, expect_content in tests: + with self.subTest(pkg_name): + top_level = [ + path for path in files(pkg_name) if path.name == 'top_level.txt' + ][0] + self.assertEqual(top_level.read_text(), expect_content) def test_entry_points(self): eps = entry_points() @@ -124,62 +139,6 @@ def test_entry_points_missing_name(self): def test_entry_points_missing_group(self): assert entry_points(group='missing') == () - def test_entry_points_dict_construction(self): - """ - Prior versions of entry_points() returned simple lists and - allowed casting those lists into maps by name using ``dict()``. - Capture this now deprecated use-case. - """ - with suppress_known_deprecation() as caught: - eps = dict(entry_points(group='entries')) - - assert 'main' in eps - assert eps['main'] == entry_points(group='entries')['main'] - - # check warning - expected = next(iter(caught)) - assert expected.category is DeprecationWarning - assert "Construction of dict of EntryPoints is deprecated" in str(expected) - - def test_entry_points_by_index(self): - """ - Prior versions of Distribution.entry_points would return a - tuple that allowed access by index. - Capture this now deprecated use-case - See python/importlib_metadata#300 and bpo-44246. - """ - eps = distribution('distinfo-pkg').entry_points - with suppress_known_deprecation() as caught: - eps[0] - - # check warning - expected = next(iter(caught)) - assert expected.category is DeprecationWarning - assert "Accessing entry points by index is deprecated" in str(expected) - - def test_entry_points_groups_getitem(self): - """ - Prior versions of entry_points() returned a dict. Ensure - that callers using '.__getitem__()' are supported but warned to - migrate. - """ - with suppress_known_deprecation(): - entry_points()['entries'] == entry_points(group='entries') - - with self.assertRaises(KeyError): - entry_points()['missing'] - - def test_entry_points_groups_get(self): - """ - Prior versions of entry_points() returned a dict. Ensure - that callers using '.get()' are supported but warned to - migrate. - """ - with suppress_known_deprecation(): - entry_points().get('missing', 'default') == 'default' - entry_points().get('entries', 'default') == entry_points()['entries'] - entry_points().get('missing', ()) == () - # TODO: RUSTPYTHON @unittest.expectedFailure def test_entry_points_allows_no_attributes(self): @@ -195,6 +154,28 @@ def test_metadata_for_this_package(self): classifiers = md.get_all('Classifier') assert 'Topic :: Software Development :: Libraries' in classifiers + def test_missing_key_legacy(self): + """ + Requesting a missing key will still return None, but warn. + """ + md = metadata('distinfo-pkg') + with suppress_known_deprecation(): + assert md['does-not-exist'] is None + + def test_get_key(self): + """ + Getting a key gets the key. + """ + md = metadata('egginfo-pkg') + assert md.get('Name') == 'egginfo-pkg' + + def test_get_missing_key(self): + """ + Requesting a missing key will return None. + """ + md = metadata('distinfo-pkg') + assert md.get('does-not-exist') is None + @staticmethod def _test_files(files): root = files[0].root @@ -217,6 +198,9 @@ def test_files_dist_info(self): def test_files_egg_info(self): self._test_files(files('egginfo-pkg')) + self._test_files(files('egg_with_module-pkg')) + self._test_files(files('egg_with_no_modules-pkg')) + self._test_files(files('sources_fallback-pkg')) def test_version_egg_info_file(self): self.assertEqual(version('egginfo-file'), '0.1') diff --git a/Lib/test/test_importlib/test_namespace_pkgs.py b/Lib/test/test_importlib/test_namespace_pkgs.py index cd08498545..65428c3d3e 100644 --- a/Lib/test/test_importlib/test_namespace_pkgs.py +++ b/Lib/test/test_importlib/test_namespace_pkgs.py @@ -79,12 +79,9 @@ def test_cant_import_other(self): with self.assertRaises(ImportError): import foo.two - def test_module_repr(self): + def test_simple_repr(self): import foo.one - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - self.assertEqual(foo.__spec__.loader.module_repr(foo), - "") + assert repr(foo).startswith("'.format(module.__name__) - self.module.__loader__ = Loader() - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, '') - - def test_module___loader___module_repr_bad(self): - class Loader(TestLoader): - def module_repr(self, module): - raise Exception - self.module.__loader__ = Loader() - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ')>'.format('spam')) - - def test_module___spec__(self): - origin = 'in a hole, in the ground' - self.spec.origin = origin - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('spam', origin)) - - def test_module___spec___location(self): - location = 'in_a_galaxy_far_far_away.py' - self.spec.origin = location - self.spec._set_fileattr = True - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ''.format('spam', location)) - - def test_module___spec___no_origin(self): - self.spec.loader = TestLoader() - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ')>'.format('spam')) - - def test_module___spec___no_origin_no_loader(self): - self.spec.loader = None - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('spam')) - - def test_module_no_name(self): - del self.module.__name__ - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('?')) - - def test_module_with_file(self): - filename = 'e/i/e/i/o/spam.py' - self.module.__file__ = filename - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ''.format('spam', filename)) - - def test_module_no_file(self): - self.module.__loader__ = TestLoader() - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ')>'.format('spam')) - - def test_module_no_file_no_loader(self): - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('spam')) - - -(Frozen_ModuleReprTests, - Source_ModuleReprTests - ) = test_util.test_both(ModuleReprTests, init=init, util=util, - machinery=machinery) - - class FactoryTests: def setUp(self): diff --git a/Lib/test/test_importlib/test_threaded_import.py b/Lib/test/test_importlib/test_threaded_import.py index 49c02484b7..148b2e4370 100644 --- a/Lib/test/test_importlib/test_threaded_import.py +++ b/Lib/test/test_importlib/test_threaded_import.py @@ -16,7 +16,7 @@ import unittest from unittest import mock from test.support import verbose -from test.support.import_helper import forget +from test.support.import_helper import forget, mock_register_at_fork from test.support.os_helper import (TESTFN, unlink, rmtree) from test.support import script_helper, threading_helper @@ -42,12 +42,6 @@ def task(N, done, done_tasks, errors): if finished: done.set() -def mock_register_at_fork(func): - # bpo-30599: Mock os.register_at_fork() when importing the random module, - # since this function doesn't allow to unregister callbacks and would leak - # memory. - return mock.patch('os.register_at_fork', create=True)(func) - # Create a circular import structure: A -> C -> B -> D -> A # NOTE: `time` is already loaded and therefore doesn't threaten to deadlock. @@ -251,7 +245,8 @@ def target(): self.addCleanup(forget, TESTFN) self.addCleanup(rmtree, '__pycache__') importlib.invalidate_caches() - __import__(TESTFN) + with threading_helper.wait_threads_exit(): + __import__(TESTFN) del sys.modules[TESTFN] @unittest.skip("TODO: RUSTPYTHON; hang") diff --git a/Lib/test/test_importlib/test_util.py b/Lib/test/test_importlib/test_util.py index 6c791fc012..dc27e4aa99 100644 --- a/Lib/test/test_importlib/test_util.py +++ b/Lib/test/test_importlib/test_util.py @@ -8,14 +8,29 @@ import importlib.util import os import pathlib +import re import string import sys from test import support +import textwrap import types import unittest import unittest.mock import warnings +try: + import _testsinglephase +except ImportError: + _testsinglephase = None +try: + import _testmultiphase +except ImportError: + _testmultiphase = None +try: + import _xxsubinterpreters as _interpreters +except ModuleNotFoundError: + _interpreters = None + class DecodeSourceBytesTests: @@ -127,247 +142,6 @@ def test___cached__(self): util=importlib_util) -class ModuleForLoaderTests: - - """Tests for importlib.util.module_for_loader.""" - - @classmethod - def module_for_loader(cls, func): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - return cls.util.module_for_loader(func) - - def test_warning(self): - # Should raise a PendingDeprecationWarning when used. - with warnings.catch_warnings(): - warnings.simplefilter('error', DeprecationWarning) - with self.assertRaises(DeprecationWarning): - func = self.util.module_for_loader(lambda x: x) - - def return_module(self, name): - fxn = self.module_for_loader(lambda self, module: module) - return fxn(self, name) - - def raise_exception(self, name): - def to_wrap(self, module): - raise ImportError - fxn = self.module_for_loader(to_wrap) - try: - fxn(self, name) - except ImportError: - pass - - def test_new_module(self): - # Test that when no module exists in sys.modules a new module is - # created. - module_name = 'a.b.c' - with util.uncache(module_name): - module = self.return_module(module_name) - self.assertIn(module_name, sys.modules) - self.assertIsInstance(module, types.ModuleType) - self.assertEqual(module.__name__, module_name) - - def test_reload(self): - # Test that a module is reused if already in sys.modules. - class FakeLoader: - def is_package(self, name): - return True - @self.module_for_loader - def load_module(self, module): - return module - name = 'a.b.c' - module = types.ModuleType('a.b.c') - module.__loader__ = 42 - module.__package__ = 42 - with util.uncache(name): - sys.modules[name] = module - loader = FakeLoader() - returned_module = loader.load_module(name) - self.assertIs(returned_module, sys.modules[name]) - self.assertEqual(module.__loader__, loader) - self.assertEqual(module.__package__, name) - - def test_new_module_failure(self): - # Test that a module is removed from sys.modules if added but an - # exception is raised. - name = 'a.b.c' - with util.uncache(name): - self.raise_exception(name) - self.assertNotIn(name, sys.modules) - - def test_reload_failure(self): - # Test that a failure on reload leaves the module in-place. - name = 'a.b.c' - module = types.ModuleType(name) - with util.uncache(name): - sys.modules[name] = module - self.raise_exception(name) - self.assertIs(module, sys.modules[name]) - - def test_decorator_attrs(self): - def fxn(self, module): pass - wrapped = self.module_for_loader(fxn) - self.assertEqual(wrapped.__name__, fxn.__name__) - self.assertEqual(wrapped.__qualname__, fxn.__qualname__) - - def test_false_module(self): - # If for some odd reason a module is considered false, still return it - # from sys.modules. - class FalseModule(types.ModuleType): - def __bool__(self): return False - - name = 'mod' - module = FalseModule(name) - with util.uncache(name): - self.assertFalse(module) - sys.modules[name] = module - given = self.return_module(name) - self.assertIs(given, module) - - def test_attributes_set(self): - # __name__, __loader__, and __package__ should be set (when - # is_package() is defined; undefined implicitly tested elsewhere). - class FakeLoader: - def __init__(self, is_package): - self._pkg = is_package - def is_package(self, name): - return self._pkg - @self.module_for_loader - def load_module(self, module): - return module - - name = 'pkg.mod' - with util.uncache(name): - loader = FakeLoader(False) - module = loader.load_module(name) - self.assertEqual(module.__name__, name) - self.assertIs(module.__loader__, loader) - self.assertEqual(module.__package__, 'pkg') - - name = 'pkg.sub' - with util.uncache(name): - loader = FakeLoader(True) - module = loader.load_module(name) - self.assertEqual(module.__name__, name) - self.assertIs(module.__loader__, loader) - self.assertEqual(module.__package__, name) - - -(Frozen_ModuleForLoaderTests, - Source_ModuleForLoaderTests - ) = util.test_both(ModuleForLoaderTests, util=importlib_util) - - -class SetPackageTests: - - """Tests for importlib.util.set_package.""" - - def verify(self, module, expect): - """Verify the module has the expected value for __package__ after - passing through set_package.""" - fxn = lambda: module - wrapped = self.util.set_package(fxn) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - wrapped() - self.assertTrue(hasattr(module, '__package__')) - self.assertEqual(expect, module.__package__) - - def test_top_level(self): - # __package__ should be set to the empty string if a top-level module. - # Implicitly tests when package is set to None. - module = types.ModuleType('module') - module.__package__ = None - self.verify(module, '') - - def test_package(self): - # Test setting __package__ for a package. - module = types.ModuleType('pkg') - module.__path__ = [''] - module.__package__ = None - self.verify(module, 'pkg') - - def test_submodule(self): - # Test __package__ for a module in a package. - module = types.ModuleType('pkg.mod') - module.__package__ = None - self.verify(module, 'pkg') - - def test_setting_if_missing(self): - # __package__ should be set if it is missing. - module = types.ModuleType('mod') - if hasattr(module, '__package__'): - delattr(module, '__package__') - self.verify(module, '') - - def test_leaving_alone(self): - # If __package__ is set and not None then leave it alone. - for value in (True, False): - module = types.ModuleType('mod') - module.__package__ = value - self.verify(module, value) - - def test_decorator_attrs(self): - def fxn(module): pass - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - wrapped = self.util.set_package(fxn) - self.assertEqual(wrapped.__name__, fxn.__name__) - self.assertEqual(wrapped.__qualname__, fxn.__qualname__) - - -(Frozen_SetPackageTests, - Source_SetPackageTests - ) = util.test_both(SetPackageTests, util=importlib_util) - - -class SetLoaderTests: - - """Tests importlib.util.set_loader().""" - - @property - def DummyLoader(self): - # Set DummyLoader on the class lazily. - class DummyLoader: - @self.util.set_loader - def load_module(self, module): - return self.module - self.__class__.DummyLoader = DummyLoader - return DummyLoader - - def test_no_attribute(self): - loader = self.DummyLoader() - loader.module = types.ModuleType('blah') - try: - del loader.module.__loader__ - except AttributeError: - pass - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(loader, loader.load_module('blah').__loader__) - - def test_attribute_is_None(self): - loader = self.DummyLoader() - loader.module = types.ModuleType('blah') - loader.module.__loader__ = None - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(loader, loader.load_module('blah').__loader__) - - def test_not_reset(self): - loader = self.DummyLoader() - loader.module = types.ModuleType('blah') - loader.module.__loader__ = 42 - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(42, loader.load_module('blah').__loader__) - - -(Frozen_SetLoaderTests, - Source_SetLoaderTests - ) = util.test_both(SetLoaderTests, util=importlib_util) - - class ResolveNameTests: """Tests importlib.util.resolve_name().""" @@ -877,7 +651,7 @@ def test_magic_number(self): # stakeholders such as OS package maintainers must be notified # in advance. Such exceptional releases will then require an # adjustment to this test case. - EXPECTED_MAGIC_NUMBER = 3495 + EXPECTED_MAGIC_NUMBER = 3531 actual = int.from_bytes(importlib.util.MAGIC_NUMBER[:2], 'little') msg = ( @@ -895,5 +669,111 @@ def test_magic_number(self): self.assertEqual(EXPECTED_MAGIC_NUMBER, actual, msg) +@unittest.skipIf(_interpreters is None, 'subinterpreters required') +class IncompatibleExtensionModuleRestrictionsTests(unittest.TestCase): + + ERROR = re.compile("^: module (.*) does not support loading in subinterpreters") + + def run_with_own_gil(self, script): + interpid = _interpreters.create(isolated=True) + try: + _interpreters.run_string(interpid, script) + except _interpreters.RunFailedError as exc: + if m := self.ERROR.match(str(exc)): + modname, = m.groups() + raise ImportError(modname) + + def run_with_shared_gil(self, script): + interpid = _interpreters.create(isolated=False) + try: + _interpreters.run_string(interpid, script) + except _interpreters.RunFailedError as exc: + if m := self.ERROR.match(str(exc)): + modname, = m.groups() + raise ImportError(modname) + + @unittest.skipIf(_testsinglephase is None, "test requires _testsinglephase module") + def test_single_phase_init_module(self): + script = textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=True): + import _testsinglephase + ''') + with self.subTest('check disabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check disabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + script = textwrap.dedent(f''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=False): + import _testsinglephase + ''') + with self.subTest('check enabled, shared GIL'): + with self.assertRaises(ImportError): + self.run_with_shared_gil(script) + with self.subTest('check enabled, per-interpreter GIL'): + with self.assertRaises(ImportError): + self.run_with_own_gil(script) + + @unittest.skipIf(_testmultiphase is None, "test requires _testmultiphase module") + def test_incomplete_multi_phase_init_module(self): + prescript = textwrap.dedent(f''' + from importlib.util import spec_from_loader, module_from_spec + from importlib.machinery import ExtensionFileLoader + + name = '_test_shared_gil_only' + filename = {_testmultiphase.__file__!r} + loader = ExtensionFileLoader(name, filename) + spec = spec_from_loader(name, loader) + + ''') + + script = prescript + textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=True): + module = module_from_spec(spec) + loader.exec_module(module) + ''') + with self.subTest('check disabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check disabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + script = prescript + textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=False): + module = module_from_spec(spec) + loader.exec_module(module) + ''') + with self.subTest('check enabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check enabled, per-interpreter GIL'): + with self.assertRaises(ImportError): + self.run_with_own_gil(script) + + @unittest.skipIf(_testmultiphase is None, "test requires _testmultiphase module") + def test_complete_multi_phase_init_module(self): + script = textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=True): + import _testmultiphase + ''') + with self.subTest('check disabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check disabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + script = textwrap.dedent(f''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=False): + import _testmultiphase + ''') + with self.subTest('check enabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check enabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/test_windows.py b/Lib/test/test_importlib/test_windows.py index 051193fae0..f8a9ead9ac 100644 --- a/Lib/test/test_importlib/test_windows.py +++ b/Lib/test/test_importlib/test_windows.py @@ -92,30 +92,16 @@ class WindowsRegistryFinderTests: def test_find_spec_missing(self): spec = self.machinery.WindowsRegistryFinder.find_spec('spam') - self.assertIs(spec, None) - - def test_find_module_missing(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.WindowsRegistryFinder.find_module('spam') - self.assertIs(loader, None) + self.assertIsNone(spec) def test_module_found(self): with setup_module(self.machinery, self.test_module): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module) spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module) - self.assertIsNot(loader, None) - self.assertIsNot(spec, None) + self.assertIsNotNone(spec) def test_module_not_found(self): with setup_module(self.machinery, self.test_module, path="."): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module) spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module) - self.assertIsNone(loader) self.assertIsNone(spec) (Frozen_WindowsRegistryFinderTests, diff --git a/Lib/test/test_importlib/util.py b/Lib/test/test_importlib/util.py index 0b6dcc5eaf..c25be096e5 100644 --- a/Lib/test/test_importlib/util.py +++ b/Lib/test/test_importlib/util.py @@ -27,7 +27,7 @@ EXTENSIONS.ext = None EXTENSIONS.filename = None EXTENSIONS.file_path = None -EXTENSIONS.name = '_testcapi' +EXTENSIONS.name = '_testsinglephase' def _extension_details(): global EXTENSIONS @@ -131,9 +131,8 @@ def uncache(*names): """ for name in names: - if name in ('sys', 'marshal', 'imp'): - raise ValueError( - "cannot uncache {0}".format(name)) + if name in ('sys', 'marshal'): + raise ValueError("cannot uncache {}".format(name)) try: del sys.modules[name] except KeyError: @@ -195,8 +194,7 @@ def import_state(**kwargs): new_value = default setattr(sys, attr, new_value) if len(kwargs): - raise ValueError( - 'unrecognized arguments: {0}'.format(kwargs.keys())) + raise ValueError('unrecognized arguments: {}'.format(kwargs)) yield finally: for attr, value in originals.items(): @@ -244,30 +242,6 @@ def __exit__(self, *exc_info): self._uncache.__exit__(None, None, None) -class mock_modules(_ImporterMock): - - """Importer mock using PEP 302 APIs.""" - - def find_module(self, fullname, path=None): - if fullname not in self.modules: - return None - else: - return self - - def load_module(self, fullname): - if fullname not in self.modules: - raise ImportError - else: - sys.modules[fullname] = self.modules[fullname] - if fullname in self.module_code: - try: - self.module_code[fullname]() - except Exception: - del sys.modules[fullname] - raise - return self.modules[fullname] - - class mock_spec(_ImporterMock): """Importer mock using PEP 451 APIs.""" From 5b4a585837c8cd14dd27b67e7be6c106150946c0 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 4 Oct 2023 03:03:12 +0900 Subject: [PATCH 04/19] Mark failings tests from importlib --- Lib/test/test_importlib/import_/test_helpers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Lib/test/test_importlib/import_/test_helpers.py b/Lib/test/test_importlib/import_/test_helpers.py index 550f88d1d7..28cdc0e526 100644 --- a/Lib/test/test_importlib/import_/test_helpers.py +++ b/Lib/test/test_importlib/import_/test_helpers.py @@ -126,6 +126,8 @@ def test_gh86298_loader_is_none_and_spec_loader_is_none(self): ValueError, _bootstrap_external._bless_my_loader, bar.__dict__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_gh86298_no_spec(self): bar = ModuleType('bar') bar.__loader__ = object() @@ -135,6 +137,8 @@ def test_gh86298_no_spec(self): DeprecationWarning, _bootstrap_external._bless_my_loader, bar.__dict__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_gh86298_spec_is_none(self): bar = ModuleType('bar') bar.__loader__ = object() @@ -144,6 +148,8 @@ def test_gh86298_spec_is_none(self): DeprecationWarning, _bootstrap_external._bless_my_loader, bar.__dict__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_gh86298_no_spec_loader(self): bar = ModuleType('bar') bar.__loader__ = object() @@ -153,6 +159,8 @@ def test_gh86298_no_spec_loader(self): DeprecationWarning, _bootstrap_external._bless_my_loader, bar.__dict__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_gh86298_loader_and_spec_loader_disagree(self): bar = ModuleType('bar') bar.__loader__ = object() From 48ca7a771e0e6bb037313da08f5ea8b00886e847 Mon Sep 17 00:00:00 2001 From: CPython developers <> Date: Wed, 4 Oct 2023 00:27:37 +0900 Subject: [PATCH 05/19] Update test.support from Python3.12 --- Lib/test/support/__init__.py | 435 +++++++++++++++++++++++---- Lib/test/support/bytecode_helper.py | 101 +++++++ Lib/test/support/import_helper.py | 28 ++ Lib/test/support/interpreters.py | 23 +- Lib/test/support/os_helper.py | 48 ++- Lib/test/support/socket_helper.py | 77 ++++- Lib/test/support/testresult.py | 10 +- Lib/test/support/threading_helper.py | 27 +- Lib/test/support/warnings_helper.py | 2 +- Lib/test/test_support.py | 92 +++++- 10 files changed, 733 insertions(+), 110 deletions(-) diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index e9736cd5ba..975ff21101 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -4,13 +4,16 @@ raise ImportError('support must be imported from the test package') import contextlib +import dataclasses import functools import getpass +import opcode import os import re import stat import sys import sysconfig +import textwrap import time import types import unittest @@ -19,11 +22,6 @@ from .testresult import get_test_runner -try: - from _testcapi import unicode_legacy_string -except ImportError: - unicode_legacy_string = None - __all__ = [ # globals "PIPE_MAX_SIZE", "verbose", "max_memuse", "use_resources", "failfast", @@ -36,7 +34,7 @@ "is_resource_enabled", "requires", "requires_freebsd_version", "requires_linux_version", "requires_mac_ver", "check_syntax_error", - "BasicTestRunner", "run_unittest", "run_doctest", + "run_unittest", "run_doctest", "requires_gzip", "requires_bz2", "requires_lzma", "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", "requires_IEEE_754", "requires_zlib", @@ -46,9 +44,12 @@ "anticipate_failure", "load_package_tests", "detect_api_mismatch", "check__all__", "skip_if_buggy_ucrt_strfptime", "check_disallow_instantiation", "check_sanitizer", "skip_if_sanitizer", + "requires_limited_api", "requires_specialization", # sys "is_jython", "is_android", "is_emscripten", "is_wasi", "check_impl_detail", "unix_shell", "setswitchinterval", + # os + "get_pagesize", # network "open_urlresource", # processes @@ -59,6 +60,8 @@ "run_with_tz", "PGO", "missing_compiler_executable", "ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST", "LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT", + "Py_DEBUG", "EXCEEDS_RECURSION_LIMIT", "C_RECURSION_LIMIT", + "skip_on_s390x", ] @@ -116,17 +119,20 @@ class Error(Exception): class TestFailed(Error): """Test failed.""" + def __init__(self, msg, *args, stats=None): + self.msg = msg + self.stats = stats + super().__init__(msg, *args) + + def __str__(self): + return self.msg class TestFailedWithDetails(TestFailed): """Test failed.""" - def __init__(self, msg, errors, failures): - self.msg = msg + def __init__(self, msg, errors, failures, stats): self.errors = errors self.failures = failures - super().__init__(msg, errors, failures) - - def __str__(self): - return self.msg + super().__init__(msg, errors, failures, stats=stats) class TestDidNotRun(Error): """Test did not run any subtests.""" @@ -408,7 +414,7 @@ def check_sanitizer(*, address=False, memory=False, ub=False): ) address_sanitizer = ( '-fsanitize=address' in _cflags or - '--with-memory-sanitizer' in _config_args + '--with-address-sanitizer' in _config_args ) ub_sanitizer = ( '-fsanitize=undefined' in _cflags or @@ -500,9 +506,16 @@ def has_no_debug_ranges(): def requires_debug_ranges(reason='requires co_positions / debug_ranges'): return unittest.skipIf(has_no_debug_ranges(), reason) -requires_legacy_unicode_capi = unittest.skipUnless(unicode_legacy_string, - 'requires legacy Unicode C API') +def requires_legacy_unicode_capi(): + try: + from _testcapi import unicode_legacy_string + except ImportError: + unicode_legacy_string = None + + return unittest.skipUnless(unicode_legacy_string, + 'requires legacy Unicode C API') +# Is not actually used in tests, but is kept for compatibility. is_jython = sys.platform.startswith('java') is_android = hasattr(sys, 'getandroidapilevel') @@ -578,7 +591,8 @@ def darwin_malloc_err_warning(test_name): msg = ' NOTICE ' detail = (f'{test_name} may generate "malloc can\'t allocate region"\n' 'warnings on macOS systems. This behavior is known. Do not\n' - 'report a bug unless tests are also failing. See bpo-40928.') + 'report a bug unless tests are also failing.\n' + 'See https://github.com/python/cpython/issues/85100') padding, _ = shutil.get_terminal_size() print(msg.center(padding, '-')) @@ -612,6 +626,14 @@ def sortdict(dict): withcommas = ", ".join(reprpairs) return "{%s}" % withcommas + +def run_code(code: str) -> dict[str, object]: + """Run a piece of code after dedenting it, and return its global namespace.""" + ns = {} + exec(textwrap.dedent(code), ns) + return ns + + def check_syntax_error(testcase, statement, errtext='', *, lineno=None, offset=None): with testcase.assertRaisesRegex(SyntaxError, errtext) as cm: compile(statement, '', 'exec') @@ -994,12 +1016,6 @@ def wrapper(self): #======================================================================= # unittest integration. -class BasicTestRunner: - def run(self, test): - result = unittest.TestResult() - test(result) - return result - def _id(obj): return obj @@ -1078,6 +1094,18 @@ def refcount_test(test): return no_tracing(cpython_only(test)) +def requires_limited_api(test): + try: + import _testcapi + except ImportError: + return unittest.skip('needs _testcapi module')(test) + return unittest.skipUnless( + _testcapi.LIMITED_API_AVAILABLE, 'needs Limited API support')(test) + +def requires_specialization(test): + return unittest.skipUnless( + opcode.ENABLE_SPECIALIZATION, "requires specialization")(test) + def _filter_suite(suite, pred): """Recursively filter test cases in a suite based on a predicate.""" newtests = [] @@ -1090,6 +1118,29 @@ def _filter_suite(suite, pred): newtests.append(test) suite._tests = newtests +@dataclasses.dataclass(slots=True) +class TestStats: + tests_run: int = 0 + failures: int = 0 + skipped: int = 0 + + @staticmethod + def from_unittest(result): + return TestStats(result.testsRun, + len(result.failures), + len(result.skipped)) + + @staticmethod + def from_doctest(results): + return TestStats(results.attempted, + results.failed) + + def accumulate(self, stats): + self.tests_run += stats.tests_run + self.failures += stats.failures + self.skipped += stats.skipped + + def _run_suite(suite): """Run tests from a unittest.TestSuite-derived class.""" runner = get_test_runner(sys.stdout, @@ -1101,9 +1152,10 @@ def _run_suite(suite): if junit_xml_list is not None: junit_xml_list.append(result.get_xml_element()) - if not result.testsRun and not result.skipped: + if not result.testsRun and not result.skipped and not result.errors: raise TestDidNotRun if not result.wasSuccessful(): + stats = TestStats.from_unittest(result) if len(result.errors) == 1 and not result.failures: err = result.errors[0][1] elif len(result.failures) == 1 and not result.errors: @@ -1113,7 +1165,8 @@ def _run_suite(suite): if not verbose: err += "; run in verbose mode for details" errors = [(str(tc), exc_str) for tc, exc_str in result.errors] failures = [(str(tc), exc_str) for tc, exc_str in result.failures] - raise TestFailedWithDetails(err, errors, failures) + raise TestFailedWithDetails(err, errors, failures, stats=stats) + return result # By default, don't filter tests @@ -1144,7 +1197,6 @@ def _is_full_match_test(pattern): def set_match_tests(accept_patterns=None, ignore_patterns=None): global _match_test_func, _accept_test_patterns, _ignore_test_patterns - if accept_patterns is None: accept_patterns = () if ignore_patterns is None: @@ -1222,7 +1274,7 @@ def run_unittest(*classes): else: suite.addTest(loader.loadTestsFromTestCase(cls)) _filter_suite(suite, match_test) - _run_suite(suite) + return _run_suite(suite) #======================================================================= # Check for the presence of docstrings. @@ -1262,13 +1314,18 @@ def run_doctest(module, verbosity=None, optionflags=0): else: verbosity = None - f, t = doctest.testmod(module, verbose=verbosity, optionflags=optionflags) - if f: - raise TestFailed("%d of %d doctests failed" % (f, t)) + results = doctest.testmod(module, + verbose=verbosity, + optionflags=optionflags) + if results.failed: + stats = TestStats.from_doctest(results) + raise TestFailed(f"{results.failed} of {results.attempted} " + f"doctests failed", + stats=stats) if verbose: print('doctest (%s) ... %d tests with zero failures' % - (module.__name__, t)) - return f, t + (module.__name__, results.attempted)) + return results #======================================================================= @@ -1792,6 +1849,25 @@ def run_in_subinterp(code): Run code in a subinterpreter. Raise unittest.SkipTest if the tracemalloc module is enabled. """ + _check_tracemalloc() + import _testcapi + return _testcapi.run_in_subinterp(code) + + +def run_in_subinterp_with_config(code, *, own_gil=None, **config): + """ + Run code in a subinterpreter. Raise unittest.SkipTest if the tracemalloc + module is enabled. + """ + _check_tracemalloc() + import _testcapi + if own_gil is not None: + assert 'gil' not in config, (own_gil, config) + config['gil'] = 2 if own_gil else 1 + return _testcapi.run_in_subinterp_with_config(code, **config) + + +def _check_tracemalloc(): # Issue #10915, #15751: PyGILState_*() functions don't work with # sub-interpreters, the tracemalloc module uses these functions internally try: @@ -1803,8 +1879,6 @@ def run_in_subinterp(code): raise unittest.SkipTest("run_in_subinterp() cannot be used " "if tracemalloc module is tracing " "memory allocations") - import _testcapi - return _testcapi.run_in_subinterp(code) # TODO: RUSTPYTHON (comment out before) @@ -1836,15 +1910,16 @@ def missing_compiler_executable(cmd_names=[]): missing. """ - # TODO (PEP 632): alternate check without using distutils - from distutils import ccompiler, sysconfig, spawn, errors + from setuptools._distutils import ccompiler, sysconfig, spawn + from setuptools import errors + compiler = ccompiler.new_compiler() sysconfig.customize_compiler(compiler) if compiler.compiler_type == "msvc": # MSVC has no executables, so check whether initialization succeeds try: compiler.initialize() - except errors.DistutilsPlatformError: + except errors.PlatformError: return "msvc" for name in compiler.executables: if cmd_names and name not in cmd_names: @@ -1875,6 +1950,18 @@ def setswitchinterval(interval): return sys.setswitchinterval(interval) +def get_pagesize(): + """Get size of a page in bytes.""" + try: + page_size = os.sysconf('SC_PAGESIZE') + except (ValueError, AttributeError): + try: + page_size = os.sysconf('SC_PAGE_SIZE') + except (ValueError, AttributeError): + page_size = 4096 + return page_size + + @contextlib.contextmanager def disable_faulthandler(): import faulthandler @@ -2092,31 +2179,26 @@ def wait_process(pid, *, exitcode, timeout=None): if timeout is None: timeout = LONG_TIMEOUT - t0 = time.monotonic() - sleep = 0.001 - max_sleep = 0.1 - while True: + + start_time = time.monotonic() + for _ in sleeping_retry(timeout, error=False): pid2, status = os.waitpid(pid, os.WNOHANG) if pid2 != 0: break - # process is still running - - dt = time.monotonic() - t0 - if dt > timeout: - try: - os.kill(pid, signal.SIGKILL) - os.waitpid(pid, 0) - except OSError: - # Ignore errors like ChildProcessError or PermissionError - pass - - raise AssertionError(f"process {pid} is still running " - f"after {dt:.1f} seconds") + # rety: the process is still running + else: + try: + os.kill(pid, signal.SIGKILL) + os.waitpid(pid, 0) + except OSError: + # Ignore errors like ChildProcessError or PermissionError + pass - sleep = min(sleep * 2, max_sleep) - time.sleep(sleep) + dt = time.monotonic() - start_time + raise AssertionError(f"process {pid} is still running " + f"after {dt:.1f} seconds") else: - # Windows implementation + # Windows implementation: don't support timeout :-( pid2, status = os.waitpid(pid, 0) exitcode2 = os.waitstatus_to_exitcode(status) @@ -2168,20 +2250,61 @@ def check_disallow_instantiation(testcase, tp, *args, **kwds): msg = f"cannot create '{re.escape(qualname)}' instances" testcase.assertRaisesRegex(TypeError, msg, tp, *args, **kwds) +def get_recursion_depth(): + """Get the recursion depth of the caller function. + + In the __main__ module, at the module level, it should be 1. + """ + try: + import _testinternalcapi + depth = _testinternalcapi.get_recursion_depth() + except (ImportError, RecursionError) as exc: + # sys._getframe() + frame.f_back implementation. + try: + depth = 0 + frame = sys._getframe() + while frame is not None: + depth += 1 + frame = frame.f_back + finally: + # Break any reference cycles. + frame = None + + # Ignore get_recursion_depth() frame. + return max(depth - 1, 1) + +def get_recursion_available(): + """Get the number of available frames before RecursionError. + + It depends on the current recursion depth of the caller function and + sys.getrecursionlimit(). + """ + limit = sys.getrecursionlimit() + depth = get_recursion_depth() + return limit - depth + @contextlib.contextmanager -def infinite_recursion(max_depth=75): +def set_recursion_limit(limit): + """Temporarily change the recursion limit.""" + original_limit = sys.getrecursionlimit() + try: + sys.setrecursionlimit(limit) + yield + finally: + sys.setrecursionlimit(original_limit) + +def infinite_recursion(max_depth=100): """Set a lower limit for tests that interact with infinite recursions (e.g test_ast.ASTHelpers_Test.test_recursion_direct) since on some debug windows builds, due to not enough functions being inlined the stack size might not handle the default recursion limit (1000). See bpo-11105 for details.""" - - original_depth = sys.getrecursionlimit() - try: - sys.setrecursionlimit(max_depth) - yield - finally: - sys.setrecursionlimit(original_depth) + if max_depth < 3: + raise ValueError("max_depth must be at least 3, got {max_depth}") + depth = get_recursion_depth() + depth = max(depth - 1, 1) # Ignore infinite_recursion() frame. + limit = depth + max_depth + return set_recursion_limit(limit) def ignore_deprecations_from(module: str, *, like: str) -> object: token = object() @@ -2230,6 +2353,180 @@ def requires_venv_with_pip(): return unittest.skipUnless(ctypes, 'venv: pip requires ctypes') +@functools.cache +def _findwheel(pkgname): + """Try to find a wheel with the package specified as pkgname. + + If set, the wheels are searched for in WHEEL_PKG_DIR (see ensurepip). + Otherwise, they are searched for in the test directory. + """ + wheel_dir = sysconfig.get_config_var('WHEEL_PKG_DIR') or TEST_HOME_DIR + filenames = os.listdir(wheel_dir) + filenames = sorted(filenames, reverse=True) # approximate "newest" first + for filename in filenames: + # filename is like 'setuptools-67.6.1-py3-none-any.whl' + if not filename.endswith(".whl"): + continue + prefix = pkgname + '-' + if filename.startswith(prefix): + return os.path.join(wheel_dir, filename) + raise FileNotFoundError(f"No wheel for {pkgname} found in {wheel_dir}") + + +# Context manager that creates a virtual environment, install setuptools and wheel in it +# and returns the path to the venv directory and the path to the python executable +@contextlib.contextmanager +def setup_venv_with_pip_setuptools_wheel(venv_dir): + import subprocess + from .os_helper import temp_cwd + + with temp_cwd() as temp_dir: + # Create virtual environment to get setuptools + cmd = [sys.executable, '-X', 'dev', '-m', 'venv', venv_dir] + if verbose: + print() + print('Run:', ' '.join(cmd)) + subprocess.run(cmd, check=True) + + venv = os.path.join(temp_dir, venv_dir) + + # Get the Python executable of the venv + python_exe = os.path.basename(sys.executable) + if sys.platform == 'win32': + python = os.path.join(venv, 'Scripts', python_exe) + else: + python = os.path.join(venv, 'bin', python_exe) + + cmd = [python, '-X', 'dev', + '-m', 'pip', 'install', + _findwheel('setuptools'), + _findwheel('wheel')] + if verbose: + print() + print('Run:', ' '.join(cmd)) + subprocess.run(cmd, check=True) + + yield python + + +# True if Python is built with the Py_DEBUG macro defined: if +# Python is built in debug mode (./configure --with-pydebug). +Py_DEBUG = hasattr(sys, 'gettotalrefcount') + + +def late_deletion(obj): + """ + Keep a Python alive as long as possible. + + Create a reference cycle and store the cycle in an object deleted late in + Python finalization. Try to keep the object alive until the very last + garbage collection. + + The function keeps a strong reference by design. It should be called in a + subprocess to not mark a test as "leaking a reference". + """ + + # Late CPython finalization: + # - finalize_interp_clear() + # - _PyInterpreterState_Clear(): Clear PyInterpreterState members + # (ex: codec_search_path, before_forkers) + # - clear os.register_at_fork() callbacks + # - clear codecs.register() callbacks + + ref_cycle = [obj] + ref_cycle.append(ref_cycle) + + # Store a reference in PyInterpreterState.codec_search_path + import codecs + def search_func(encoding): + return None + search_func.reference = ref_cycle + codecs.register(search_func) + + if hasattr(os, 'register_at_fork'): + # Store a reference in PyInterpreterState.before_forkers + def atfork_func(): + pass + atfork_func.reference = ref_cycle + os.register_at_fork(before=atfork_func) + + +def busy_retry(timeout, err_msg=None, /, *, error=True): + """ + Run the loop body until "break" stops the loop. + + After *timeout* seconds, raise an AssertionError if *error* is true, + or just stop if *error is false. + + Example: + + for _ in support.busy_retry(support.SHORT_TIMEOUT): + if check(): + break + + Example of error=False usage: + + for _ in support.busy_retry(support.SHORT_TIMEOUT, error=False): + if check(): + break + else: + raise RuntimeError('my custom error') + + """ + if timeout <= 0: + raise ValueError("timeout must be greater than zero") + + start_time = time.monotonic() + deadline = start_time + timeout + + while True: + yield + + if time.monotonic() >= deadline: + break + + if error: + dt = time.monotonic() - start_time + msg = f"timeout ({dt:.1f} seconds)" + if err_msg: + msg = f"{msg}: {err_msg}" + raise AssertionError(msg) + + +def sleeping_retry(timeout, err_msg=None, /, + *, init_delay=0.010, max_delay=1.0, error=True): + """ + Wait strategy that applies exponential backoff. + + Run the loop body until "break" stops the loop. Sleep at each loop + iteration, but not at the first iteration. The sleep delay is doubled at + each iteration (up to *max_delay* seconds). + + See busy_retry() documentation for the parameters usage. + + Example raising an exception after SHORT_TIMEOUT seconds: + + for _ in support.sleeping_retry(support.SHORT_TIMEOUT): + if check(): + break + + Example of error=False usage: + + for _ in support.sleeping_retry(support.SHORT_TIMEOUT, error=False): + if check(): + break + else: + raise RuntimeError('my custom error') + """ + + delay = init_delay + for _ in busy_retry(timeout, err_msg, error=error): + yield + + time.sleep(delay) + delay = min(delay * 2, max_delay) + + @contextlib.contextmanager def adjust_int_max_str_digits(max_digits): """Temporarily change the integer string conversion length limit.""" @@ -2239,3 +2536,13 @@ def adjust_int_max_str_digits(max_digits): yield finally: sys.set_int_max_str_digits(current) + +#For recursion tests, easily exceeds default recursion limit +EXCEEDS_RECURSION_LIMIT = 5000 + +# The default C recursion limit (from Include/cpython/pystate.h). +C_RECURSION_LIMIT = 1500 + +#Windows doesn't have os.uname() but it doesn't support s390x. +skip_on_s390x = unittest.skipIf(hasattr(os, 'uname') and os.uname().machine == 's390x', + 'skipped on s390x') diff --git a/Lib/test/support/bytecode_helper.py b/Lib/test/support/bytecode_helper.py index 471d4a68f9..388d126677 100644 --- a/Lib/test/support/bytecode_helper.py +++ b/Lib/test/support/bytecode_helper.py @@ -3,6 +3,7 @@ import unittest import dis import io +from _testinternalcapi import compiler_codegen, optimize_cfg, assemble_code_object _UNSPECIFIED = object() @@ -16,6 +17,7 @@ def get_disassembly_as_string(self, co): def assertInBytecode(self, x, opname, argval=_UNSPECIFIED): """Returns instr if opname is found, otherwise throws AssertionError""" + self.assertIn(opname, dis.opmap) for instr in dis.get_instructions(x): if instr.opname == opname: if argval is _UNSPECIFIED or instr.argval == argval: @@ -30,6 +32,7 @@ def assertInBytecode(self, x, opname, argval=_UNSPECIFIED): def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED): """Throws AssertionError if opname is found""" + self.assertIn(opname, dis.opmap) for instr in dis.get_instructions(x): if instr.opname == opname: disassembly = self.get_disassembly_as_string(x) @@ -40,3 +43,101 @@ def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED): msg = '(%s,%r) occurs in bytecode:\n%s' msg = msg % (opname, argval, disassembly) self.fail(msg) + +class CompilationStepTestCase(unittest.TestCase): + + HAS_ARG = set(dis.hasarg) + HAS_TARGET = set(dis.hasjrel + dis.hasjabs + dis.hasexc) + HAS_ARG_OR_TARGET = HAS_ARG.union(HAS_TARGET) + + class Label: + pass + + def assertInstructionsMatch(self, actual_, expected_): + # get two lists where each entry is a label or + # an instruction tuple. Normalize the labels to the + # instruction count of the target, and compare the lists. + + self.assertIsInstance(actual_, list) + self.assertIsInstance(expected_, list) + + actual = self.normalize_insts(actual_) + expected = self.normalize_insts(expected_) + self.assertEqual(len(actual), len(expected)) + + # compare instructions + for act, exp in zip(actual, expected): + if isinstance(act, int): + self.assertEqual(exp, act) + continue + self.assertIsInstance(exp, tuple) + self.assertIsInstance(act, tuple) + # crop comparison to the provided expected values + if len(act) > len(exp): + act = act[:len(exp)] + self.assertEqual(exp, act) + + def resolveAndRemoveLabels(self, insts): + idx = 0 + res = [] + for item in insts: + assert isinstance(item, (self.Label, tuple)) + if isinstance(item, self.Label): + item.value = idx + else: + idx += 1 + res.append(item) + + return res + + def normalize_insts(self, insts): + """ Map labels to instruction index. + Map opcodes to opnames. + """ + insts = self.resolveAndRemoveLabels(insts) + res = [] + for item in insts: + assert isinstance(item, tuple) + opcode, oparg, *loc = item + opcode = dis.opmap.get(opcode, opcode) + if isinstance(oparg, self.Label): + arg = oparg.value + else: + arg = oparg if opcode in self.HAS_ARG else None + opcode = dis.opname[opcode] + res.append((opcode, arg, *loc)) + return res + + def complete_insts_info(self, insts): + # fill in omitted fields in location, and oparg 0 for ops with no arg. + res = [] + for item in insts: + assert isinstance(item, tuple) + inst = list(item) + opcode = dis.opmap[inst[0]] + oparg = inst[1] + loc = inst[2:] + [-1] * (6 - len(inst)) + res.append((opcode, oparg, *loc)) + return res + + +class CodegenTestCase(CompilationStepTestCase): + + def generate_code(self, ast): + insts, _ = compiler_codegen(ast, "my_file.py", 0) + return insts + + +class CfgOptimizationTestCase(CompilationStepTestCase): + + def get_optimized(self, insts, consts, nlocals=0): + insts = self.normalize_insts(insts) + insts = self.complete_insts_info(insts) + insts = optimize_cfg(insts, consts, nlocals) + return insts, consts + +class AssemblerTestCase(CompilationStepTestCase): + + def get_code_object(self, filename, insts, metadata): + co = assemble_code_object(filename, insts, metadata) + return co diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 5201dc84cf..67f18e530e 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -105,6 +105,26 @@ def frozen_modules(enabled=True): _imp._override_frozen_modules_for_tests(0) +@contextlib.contextmanager +def multi_interp_extensions_check(enabled=True): + """Force legacy modules to be allowed in subinterpreters (or not). + + ("legacy" == single-phase init) + + This only applies to modules that haven't been imported yet. + It overrides the PyInterpreterConfig.check_multi_interp_extensions + setting (see support.run_in_subinterp_with_config() and + _xxsubinterpreters.create()). + + Also see importlib.utils.allowing_all_extensions(). + """ + old = _imp._override_multi_interp_extensions_check(1 if enabled else -1) + try: + yield + finally: + _imp._override_multi_interp_extensions_check(old) + + def import_fresh_module(name, fresh=(), blocked=(), *, deprecated=False, usefrozen=False, @@ -246,3 +266,11 @@ def modules_cleanup(oldmodules): # do currently). Implicitly imported *real* modules should be left alone # (see issue 10556). sys.modules.update(oldmodules) + + +def mock_register_at_fork(func): + # bpo-30599: Mock os.register_at_fork() when importing the random module, + # since this function doesn't allow to unregister callbacks and would leak + # memory. + from unittest import mock + return mock.patch('os.register_at_fork', create=True)(func) diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py index 2935708f9d..5c484d1170 100644 --- a/Lib/test/support/interpreters.py +++ b/Lib/test/support/interpreters.py @@ -2,11 +2,12 @@ import time import _xxsubinterpreters as _interpreters +import _xxinterpchannels as _channels # aliases: -from _xxsubinterpreters import ( +from _xxsubinterpreters import is_shareable, RunFailedError +from _xxinterpchannels import ( ChannelError, ChannelNotFoundError, ChannelEmptyError, - is_shareable, ) @@ -102,7 +103,7 @@ def create_channel(): The channel may be used to pass data safely between interpreters. """ - cid = _interpreters.channel_create() + cid = _channels.create() recv, send = RecvChannel(cid), SendChannel(cid) return recv, send @@ -110,14 +111,14 @@ def create_channel(): def list_all_channels(): """Return a list of (recv, send) for all open channels.""" return [(RecvChannel(cid), SendChannel(cid)) - for cid in _interpreters.channel_list_all()] + for cid in _channels.list_all()] class _ChannelEnd: """The base class for RecvChannel and SendChannel.""" def __init__(self, id): - if not isinstance(id, (int, _interpreters.ChannelID)): + if not isinstance(id, (int, _channels.ChannelID)): raise TypeError(f'id must be an int, got {id!r}') self._id = id @@ -152,10 +153,10 @@ def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds This blocks until an object has been sent, if none have been sent already. """ - obj = _interpreters.channel_recv(self._id, _sentinel) + obj = _channels.recv(self._id, _sentinel) while obj is _sentinel: time.sleep(_delay) - obj = _interpreters.channel_recv(self._id, _sentinel) + obj = _channels.recv(self._id, _sentinel) return obj def recv_nowait(self, default=_NOT_SET): @@ -166,9 +167,9 @@ def recv_nowait(self, default=_NOT_SET): is the same as recv(). """ if default is _NOT_SET: - return _interpreters.channel_recv(self._id) + return _channels.recv(self._id) else: - return _interpreters.channel_recv(self._id, default) + return _channels.recv(self._id, default) class SendChannel(_ChannelEnd): @@ -179,7 +180,7 @@ def send(self, obj): This blocks until the object is received. """ - _interpreters.channel_send(self._id, obj) + _channels.send(self._id, obj) # XXX We are missing a low-level channel_send_wait(). # See bpo-32604 and gh-19829. # Until that shows up we fake it: @@ -194,4 +195,4 @@ def send_nowait(self, obj): # XXX Note that at the moment channel_send() only ever returns # None. This should be fixed when channel_send_wait() is added. # See bpo-32604 and gh-19829. - return _interpreters.channel_send(self._id, obj) + return _channels.send(self._id, obj) diff --git a/Lib/test/support/os_helper.py b/Lib/test/support/os_helper.py index f599cc7521..821a4b1ffd 100644 --- a/Lib/test/support/os_helper.py +++ b/Lib/test/support/os_helper.py @@ -4,6 +4,7 @@ import os import re import stat +import string import sys import time import unittest @@ -11,11 +12,7 @@ # Filename used for testing -if os.name == 'java': - # Jython disallows @ in module names - TESTFN_ASCII = '$test' -else: - TESTFN_ASCII = '@test' +TESTFN_ASCII = '@test' # Disambiguate TESTFN for parallel testing, while letting it remain a valid # module name. @@ -141,6 +138,11 @@ try: name.decode(sys.getfilesystemencoding()) except UnicodeDecodeError: + try: + name.decode(sys.getfilesystemencoding(), + sys.getfilesystemencodeerrors()) + except UnicodeDecodeError: + continue TESTFN_UNDECODABLE = os.fsencode(TESTFN_ASCII) + name break @@ -567,7 +569,7 @@ def fs_is_case_insensitive(directory): class FakePath: - """Simple implementing of the path protocol. + """Simple implementation of the path protocol. """ def __init__(self, path): self.path = path @@ -715,3 +717,37 @@ def __exit__(self, *ignore_exc): else: self._environ[k] = v os.environ = self._environ + + +try: + import ctypes + kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + + ERROR_FILE_NOT_FOUND = 2 + DDD_REMOVE_DEFINITION = 2 + DDD_EXACT_MATCH_ON_REMOVE = 4 + DDD_NO_BROADCAST_SYSTEM = 8 +except (ImportError, AttributeError): + def subst_drive(path): + raise unittest.SkipTest('ctypes or kernel32 is not available') +else: + @contextlib.contextmanager + def subst_drive(path): + """Temporarily yield a substitute drive for a given path.""" + for c in reversed(string.ascii_uppercase): + drive = f'{c}:' + if (not kernel32.QueryDosDeviceW(drive, None, 0) and + ctypes.get_last_error() == ERROR_FILE_NOT_FOUND): + break + else: + raise unittest.SkipTest('no available logical drive') + if not kernel32.DefineDosDeviceW( + DDD_NO_BROADCAST_SYSTEM, drive, path): + raise ctypes.WinError(ctypes.get_last_error()) + try: + yield drive + finally: + if not kernel32.DefineDosDeviceW( + DDD_REMOVE_DEFINITION | DDD_EXACT_MATCH_ON_REMOVE, + drive, path): + raise ctypes.WinError(ctypes.get_last_error()) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 42b2a93398..e85d912c77 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -1,8 +1,11 @@ import contextlib import errno +import os.path import socket -import unittest import sys +import subprocess +import tempfile +import unittest from .. import support from . import warnings_helper @@ -61,7 +64,7 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): http://bugs.python.org/issue2550 for more info. The following site also has a very thorough description about the implications of both REUSEADDR and EXCLUSIVEADDRUSE on Windows: - http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx) + https://learn.microsoft.com/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse XXX: although this approach is a vast improvement on previous attempts to elicit unused ports, it rests heavily on the assumption that the ephemeral @@ -270,3 +273,73 @@ def filter_error(err): # __cause__ or __context__? finally: socket.setdefaulttimeout(old_timeout) + + +def create_unix_domain_name(): + """ + Create a UNIX domain name: socket.bind() argument of a AF_UNIX socket. + + Return a path relative to the current directory to get a short path + (around 27 ASCII characters). + """ + return tempfile.mktemp(prefix="test_python_", suffix='.sock', + dir=os.path.curdir) + + +# consider that sysctl values should not change while tests are running +_sysctl_cache = {} + +def _get_sysctl(name): + """Get a sysctl value as an integer.""" + try: + return _sysctl_cache[name] + except KeyError: + pass + + # At least Linux and FreeBSD support the "-n" option + cmd = ['sysctl', '-n', name] + proc = subprocess.run(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + if proc.returncode: + support.print_warning(f'{' '.join(cmd)!r} command failed with ' + f'exit code {proc.returncode}') + # cache the error to only log the warning once + _sysctl_cache[name] = None + return None + output = proc.stdout + + # Parse '0\n' to get '0' + try: + value = int(output.strip()) + except Exception as exc: + support.print_warning(f'Failed to parse {' '.join(cmd)!r} ' + f'command output {output!r}: {exc!r}') + # cache the error to only log the warning once + _sysctl_cache[name] = None + return None + + _sysctl_cache[name] = value + return value + + +def tcp_blackhole(): + if not sys.platform.startswith('freebsd'): + return False + + # gh-109015: test if FreeBSD TCP blackhole is enabled + value = _get_sysctl('net.inet.tcp.blackhole') + if value is None: + # don't skip if we fail to get the sysctl value + return False + return (value != 0) + + +def skip_if_tcp_blackhole(test): + """Decorator skipping test if TCP blackhole is enabled.""" + skip_if = unittest.skipIf( + tcp_blackhole(), + "TCP blackhole is enabled (sysctl net.inet.tcp.blackhole)" + ) + return skip_if(test) diff --git a/Lib/test/support/testresult.py b/Lib/test/support/testresult.py index 2cd1366cd8..de23fdd59d 100644 --- a/Lib/test/support/testresult.py +++ b/Lib/test/support/testresult.py @@ -8,6 +8,7 @@ import time import traceback import unittest +from test import support class RegressionTestResult(unittest.TextTestResult): USE_XML = False @@ -18,10 +19,13 @@ def __init__(self, stream, descriptions, verbosity): self.buffer = True if self.USE_XML: from xml.etree import ElementTree as ET - from datetime import datetime + from datetime import datetime, UTC self.__ET = ET self.__suite = ET.Element('testsuite') - self.__suite.set('start', datetime.utcnow().isoformat(' ')) + self.__suite.set('start', + datetime.now(UTC) + .replace(tzinfo=None) + .isoformat(' ')) self.__e = None self.__start_time = None @@ -109,6 +113,8 @@ def addExpectedFailure(self, test, err): def addFailure(self, test, err): self._add_result(test, True, failure=self.__makeErrorDict(*err)) super().addFailure(test, err) + if support.failfast: + self.stop() def addSkip(self, test, reason): self._add_result(test, skipped=reason) diff --git a/Lib/test/support/threading_helper.py b/Lib/test/support/threading_helper.py index 26cbc6f4d2..7f16050f32 100644 --- a/Lib/test/support/threading_helper.py +++ b/Lib/test/support/threading_helper.py @@ -88,19 +88,17 @@ def wait_threads_exit(timeout=None): yield finally: start_time = time.monotonic() - deadline = start_time + timeout - while True: + for _ in support.sleeping_retry(timeout, error=False): + support.gc_collect() count = _thread._count() if count <= old_count: break - if time.monotonic() > deadline: - dt = time.monotonic() - start_time - msg = (f"wait_threads() failed to cleanup {count - old_count} " - f"threads after {dt:.1f} seconds " - f"(count: {count}, old count: {old_count})") - raise AssertionError(msg) - time.sleep(0.010) - support.gc_collect() + else: + dt = time.monotonic() - start_time + msg = (f"wait_threads() failed to cleanup {count - old_count} " + f"threads after {dt:.1f} seconds " + f"(count: {count}, old count: {old_count})") + raise AssertionError(msg) def join_thread(thread, timeout=None): @@ -117,7 +115,11 @@ def join_thread(thread, timeout=None): @contextlib.contextmanager def start_threads(threads, unlock=None): - import faulthandler + try: + import faulthandler + except ImportError: + # It isn't supported on subinterpreters yet. + faulthandler = None threads = list(threads) started = [] try: @@ -149,7 +151,8 @@ def start_threads(threads, unlock=None): finally: started = [t for t in started if t.is_alive()] if started: - faulthandler.dump_traceback(sys.stdout) + if faulthandler is not None: + faulthandler.dump_traceback(sys.stdout) raise AssertionError('Unable to join %d threads' % len(started)) diff --git a/Lib/test/support/warnings_helper.py b/Lib/test/support/warnings_helper.py index 28e96f88b2..c1bf056230 100644 --- a/Lib/test/support/warnings_helper.py +++ b/Lib/test/support/warnings_helper.py @@ -44,7 +44,7 @@ def check_syntax_warning(testcase, statement, errtext='', def ignore_warnings(*, category): - """Decorator to suppress deprecation warnings. + """Decorator to suppress warnings. Use of context managers to hide warnings make diffs more noisy and tools like 'git blame' less useful. diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 6ad272697b..75f9be1b9c 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -9,7 +9,6 @@ import sys import tempfile import textwrap -import time import unittest import warnings @@ -31,7 +30,7 @@ def setUpClass(cls): "test.support.warnings_helper", like=".*used in test_support.*" ) cls._test_support_token = support.ignore_deprecations_from( - "test.test_support", like=".*You should NOT be seeing this.*" + __name__, like=".*You should NOT be seeing this.*" ) assert len(warnings.filters) == orig_filter_len + 2 @@ -464,18 +463,12 @@ def test_reap_children(self): # child process: do nothing, just exit os._exit(0) - t0 = time.monotonic() - deadline = time.monotonic() + support.SHORT_TIMEOUT - was_altered = support.environment_altered try: support.environment_altered = False stderr = io.StringIO() - while True: - if time.monotonic() > deadline: - self.fail("timeout") - + for _ in support.sleeping_retry(support.SHORT_TIMEOUT): with support.swap_attr(support.print_warning, 'orig_stderr', stderr): support.reap_children() @@ -484,9 +477,6 @@ def test_reap_children(self): if support.environment_altered: break - # loop until the child process completed - time.sleep(0.100) - msg = "Warning -- reap_children() reaped child process %s" % pid self.assertIn(msg, stderr.getvalue()) self.assertTrue(support.environment_altered) @@ -513,6 +503,7 @@ def check_options(self, args, func, expected=None): self.assertEqual(proc.stdout.rstrip(), repr(expected)) self.assertEqual(proc.returncode, 0) + @support.requires_resource('cpu') def test_args_from_interpreter_flags(self): # Test test.support.args_from_interpreter_flags() for opts in ( @@ -702,6 +693,83 @@ def test_has_strftime_extensions(self): else: self.assertTrue(support.has_strftime_extensions) + def test_get_recursion_depth(self): + # test support.get_recursion_depth() + code = textwrap.dedent(""" + from test import support + import sys + + def check(cond): + if not cond: + raise AssertionError("test failed") + + # depth 1 + check(support.get_recursion_depth() == 1) + + # depth 2 + def test_func(): + check(support.get_recursion_depth() == 2) + test_func() + + def test_recursive(depth, limit): + if depth >= limit: + # cannot call get_recursion_depth() at this depth, + # it can raise RecursionError + return + get_depth = support.get_recursion_depth() + print(f"test_recursive: {depth}/{limit}: " + f"get_recursion_depth() says {get_depth}") + check(get_depth == depth) + test_recursive(depth + 1, limit) + + # depth up to 25 + with support.infinite_recursion(max_depth=25): + limit = sys.getrecursionlimit() + print(f"test with sys.getrecursionlimit()={limit}") + test_recursive(2, limit) + + # depth up to 500 + with support.infinite_recursion(max_depth=500): + limit = sys.getrecursionlimit() + print(f"test with sys.getrecursionlimit()={limit}") + test_recursive(2, limit) + """) + script_helper.assert_python_ok("-c", code) + + def test_recursion(self): + # Test infinite_recursion() and get_recursion_available() functions. + def recursive_function(depth): + if depth: + recursive_function(depth - 1) + + for max_depth in (5, 25, 250): + with support.infinite_recursion(max_depth): + available = support.get_recursion_available() + + # Recursion up to 'available' additional frames should be OK. + recursive_function(available) + + # Recursion up to 'available+1' additional frames must raise + # RecursionError. Avoid self.assertRaises(RecursionError) which + # can consume more than 3 frames and so raises RecursionError. + try: + recursive_function(available + 1) + except RecursionError: + pass + else: + self.fail("RecursionError was not raised") + + # Test the bare minimumum: max_depth=3 + with support.infinite_recursion(3): + try: + recursive_function(3) + except RecursionError: + pass + else: + self.fail("RecursionError was not raised") + + #self.assertEqual(available, 2) + # XXX -follows a list of untested API # make_legacy_pyc # is_resource_enabled From d765fb6f727e98695e8276785173b75cc4f4e8c6 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 11:16:06 -0700 Subject: [PATCH 06/19] Fix unsupported parser feature --- Lib/test/support/socket_helper.py | 4 ++-- extra_tests/snippets/builtin_none.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index e85d912c77..d9c087c251 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -303,7 +303,7 @@ def _get_sysctl(name): stderr=subprocess.STDOUT, text=True) if proc.returncode: - support.print_warning(f'{' '.join(cmd)!r} command failed with ' + support.print_warning(f'{" ".join(cmd)!r} command failed with ' f'exit code {proc.returncode}') # cache the error to only log the warning once _sysctl_cache[name] = None @@ -314,7 +314,7 @@ def _get_sysctl(name): try: value = int(output.strip()) except Exception as exc: - support.print_warning(f'Failed to parse {' '.join(cmd)!r} ' + support.print_warning(f'Failed to parse {" ".join(cmd)!r} ' f'command output {output!r}: {exc!r}') # cache the error to only log the warning once _sysctl_cache[name] = None diff --git a/extra_tests/snippets/builtin_none.py b/extra_tests/snippets/builtin_none.py index b0080a9d25..c75f04ea73 100644 --- a/extra_tests/snippets/builtin_none.py +++ b/extra_tests/snippets/builtin_none.py @@ -22,5 +22,4 @@ def none2(): assert None.__eq__(3) is NotImplemented assert None.__ne__(3) is NotImplemented assert None.__eq__(None) is True -assert None.__ne__(None) is False - +# assert None.__ne__(None) is False # changed in 3.12 From d66e92a3c251090cb5a996bd2083314fe5c881df Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 11:19:22 -0700 Subject: [PATCH 07/19] mark failing test --- Lib/test/test_support.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 75f9be1b9c..673160c20b 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -693,6 +693,7 @@ def test_has_strftime_extensions(self): else: self.assertTrue(support.has_strftime_extensions) + @unittest.expectedFailure def test_get_recursion_depth(self): # test support.get_recursion_depth() code = textwrap.dedent(""" From 7ff899bb375aa9a13623ea8373934f09eb13f4cd Mon Sep 17 00:00:00 2001 From: CPython developers <> Date: Wed, 4 Oct 2023 01:09:10 +0900 Subject: [PATCH 08/19] Update functools from Python 3.12 --- Lib/functools.py | 183 +++++++++++++++++++++++++++-------------------- 1 file changed, 104 insertions(+), 79 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 8decc874e1..2ae4290f98 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -10,9 +10,9 @@ # See C source code for _functools credits/copyright __all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES', - 'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial', - 'partialmethod', 'singledispatch', 'singledispatchmethod', - "cached_property"] + 'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce', + 'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod', + 'cached_property'] from abc import get_cache_token from collections import namedtuple @@ -30,7 +30,7 @@ # wrapper functions that can handle naive introspection WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__', - '__annotations__') + '__annotations__', '__type_params__') WRAPPER_UPDATES = ('__dict__',) def update_wrapper(wrapper, wrapped, @@ -86,82 +86,86 @@ def wraps(wrapped, # infinite recursion that could occur when the operator dispatch logic # detects a NotImplemented result and then calls a reflected method. -def _gt_from_lt(self, other, NotImplemented=NotImplemented): +def _gt_from_lt(self, other): 'Return a > b. Computed by @total_ordering from (not a < b) and (a != b).' - op_result = self.__lt__(other) + op_result = type(self).__lt__(self, other) if op_result is NotImplemented: return op_result return not op_result and self != other -def _le_from_lt(self, other, NotImplemented=NotImplemented): +def _le_from_lt(self, other): 'Return a <= b. Computed by @total_ordering from (a < b) or (a == b).' - op_result = self.__lt__(other) + op_result = type(self).__lt__(self, other) + if op_result is NotImplemented: + return op_result return op_result or self == other -def _ge_from_lt(self, other, NotImplemented=NotImplemented): +def _ge_from_lt(self, other): 'Return a >= b. Computed by @total_ordering from (not a < b).' - op_result = self.__lt__(other) + op_result = type(self).__lt__(self, other) if op_result is NotImplemented: return op_result return not op_result -def _ge_from_le(self, other, NotImplemented=NotImplemented): +def _ge_from_le(self, other): 'Return a >= b. Computed by @total_ordering from (not a <= b) or (a == b).' - op_result = self.__le__(other) + op_result = type(self).__le__(self, other) if op_result is NotImplemented: return op_result return not op_result or self == other -def _lt_from_le(self, other, NotImplemented=NotImplemented): +def _lt_from_le(self, other): 'Return a < b. Computed by @total_ordering from (a <= b) and (a != b).' - op_result = self.__le__(other) + op_result = type(self).__le__(self, other) if op_result is NotImplemented: return op_result return op_result and self != other -def _gt_from_le(self, other, NotImplemented=NotImplemented): +def _gt_from_le(self, other): 'Return a > b. Computed by @total_ordering from (not a <= b).' - op_result = self.__le__(other) + op_result = type(self).__le__(self, other) if op_result is NotImplemented: return op_result return not op_result -def _lt_from_gt(self, other, NotImplemented=NotImplemented): +def _lt_from_gt(self, other): 'Return a < b. Computed by @total_ordering from (not a > b) and (a != b).' - op_result = self.__gt__(other) + op_result = type(self).__gt__(self, other) if op_result is NotImplemented: return op_result return not op_result and self != other -def _ge_from_gt(self, other, NotImplemented=NotImplemented): +def _ge_from_gt(self, other): 'Return a >= b. Computed by @total_ordering from (a > b) or (a == b).' - op_result = self.__gt__(other) + op_result = type(self).__gt__(self, other) + if op_result is NotImplemented: + return op_result return op_result or self == other -def _le_from_gt(self, other, NotImplemented=NotImplemented): +def _le_from_gt(self, other): 'Return a <= b. Computed by @total_ordering from (not a > b).' - op_result = self.__gt__(other) + op_result = type(self).__gt__(self, other) if op_result is NotImplemented: return op_result return not op_result -def _le_from_ge(self, other, NotImplemented=NotImplemented): +def _le_from_ge(self, other): 'Return a <= b. Computed by @total_ordering from (not a >= b) or (a == b).' - op_result = self.__ge__(other) + op_result = type(self).__ge__(self, other) if op_result is NotImplemented: return op_result return not op_result or self == other -def _gt_from_ge(self, other, NotImplemented=NotImplemented): +def _gt_from_ge(self, other): 'Return a > b. Computed by @total_ordering from (a >= b) and (a != b).' - op_result = self.__ge__(other) + op_result = type(self).__ge__(self, other) if op_result is NotImplemented: return op_result return op_result and self != other -def _lt_from_ge(self, other, NotImplemented=NotImplemented): +def _lt_from_ge(self, other): 'Return a < b. Computed by @total_ordering from (not a >= b).' - op_result = self.__ge__(other) + op_result = type(self).__ge__(self, other) if op_result is NotImplemented: return op_result return not op_result @@ -232,14 +236,14 @@ def __ge__(self, other): def reduce(function, sequence, initial=_initial_missing): """ - reduce(function, sequence[, initial]) -> value + reduce(function, iterable[, initial]) -> value - Apply a function of two arguments cumulatively to the items of a sequence, - from left to right, so as to reduce the sequence to a single value. - For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates + Apply a function of two arguments cumulatively to the items of a sequence + or iterable, from left to right, so as to reduce the iterable to a single + value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5). If initial is present, it is placed before the items - of the sequence in the calculation, and serves as a default when the - sequence is empty. + of the iterable in the calculation, and serves as a default when the + iterable is empty. """ it = iter(sequence) @@ -248,7 +252,8 @@ def reduce(function, sequence, initial=_initial_missing): try: value = next(it) except StopIteration: - raise TypeError("reduce() of empty sequence with no initial value") from None + raise TypeError( + "reduce() of empty iterable with no initial value") from None else: value = initial @@ -347,23 +352,7 @@ class partialmethod(object): callables as instance methods. """ - def __init__(*args, **keywords): - if len(args) >= 2: - self, func, *args = args - elif not args: - raise TypeError("descriptor '__init__' of partialmethod " - "needs an argument") - elif 'func' in keywords: - func = keywords.pop('func') - self, *args = args - import warnings - warnings.warn("Passing 'func' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) - else: - raise TypeError("type 'partialmethod' takes at least one argument, " - "got %d" % (len(args)-1)) - args = tuple(args) - + def __init__(self, func, /, *args, **keywords): if not callable(func) and not hasattr(func, "__get__"): raise TypeError("{!r} is not callable or a descriptor" .format(func)) @@ -381,7 +370,6 @@ def __init__(*args, **keywords): self.func = func self.args = args self.keywords = keywords - __init__.__text_signature__ = '($self, func, /, *args, **keywords)' def __repr__(self): args = ", ".join(map(repr, self.args)) @@ -427,6 +415,7 @@ def __isabstractmethod__(self): __class_getitem__ = classmethod(GenericAlias) + # Helper functions def _unwrap_partial(func): @@ -503,7 +492,7 @@ def lru_cache(maxsize=128, typed=False): with f.cache_info(). Clear the cache and statistics with f.cache_clear(). Access the underlying function with f.__wrapped__. - See: http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU) + See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU) """ @@ -520,6 +509,7 @@ def lru_cache(maxsize=128, typed=False): # The user_function was passed in directly via the maxsize argument user_function, maxsize = maxsize, 128 wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) + wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed} return update_wrapper(wrapper, user_function) elif maxsize is not None: raise TypeError( @@ -527,6 +517,7 @@ def lru_cache(maxsize=128, typed=False): def decorating_function(user_function): wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) + wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed} return update_wrapper(wrapper, user_function) return decorating_function @@ -653,6 +644,15 @@ def cache_clear(): pass +################################################################################ +### cache -- simplified access to the infinity cache +################################################################################ + +def cache(user_function, /): + 'Simple lightweight unbounded cache. Sometimes called "memoize".' + return lru_cache(maxsize=None)(user_function) + + ################################################################################ ### singledispatch() - single-dispatch generic function decorator ################################################################################ @@ -660,7 +660,7 @@ def cache_clear(): def _c3_merge(sequences): """Merges MROs in *sequences* to a single MRO using the C3 algorithm. - Adapted from http://www.python.org/download/releases/2.3/mro/. + Adapted from https://www.python.org/download/releases/2.3/mro/. """ result = [] @@ -740,6 +740,7 @@ def _compose_mro(cls, types): # Remove entries which are already present in the __mro__ or unrelated. def is_related(typ): return (typ not in bases and hasattr(typ, '__mro__') + and not isinstance(typ, GenericAlias) and issubclass(cls, typ)) types = [n for n in types if is_related(n)] # Remove entries which are strict bases of other entries (they will end up @@ -837,6 +838,17 @@ def dispatch(cls): dispatch_cache[cls] = impl return impl + def _is_union_type(cls): + from typing import get_origin, Union + return get_origin(cls) in {Union, types.UnionType} + + def _is_valid_dispatch_type(cls): + if isinstance(cls, type): + return True + from typing import get_args + return (_is_union_type(cls) and + all(isinstance(arg, type) for arg in get_args(cls))) + def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -844,9 +856,15 @@ def register(cls, func=None): """ nonlocal cache_token - if func is None: - if isinstance(cls, type): + if _is_valid_dispatch_type(cls): + if func is None: return lambda f: register(cls, f) + else: + if func is not None: + raise TypeError( + f"Invalid first argument to `register()`. " + f"{cls!r} is not a class or union type." + ) ann = getattr(cls, '__annotations__', {}) if not ann: raise TypeError( @@ -859,12 +877,25 @@ def register(cls, func=None): # only import typing if annotation parsing is necessary from typing import get_type_hints argname, cls = next(iter(get_type_hints(func).items())) - if not isinstance(cls, type): - raise TypeError( - f"Invalid annotation for {argname!r}. " - f"{cls!r} is not a class." - ) - registry[cls] = func + if not _is_valid_dispatch_type(cls): + if _is_union_type(cls): + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} not all arguments are classes." + ) + else: + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} is not a class." + ) + + if _is_union_type(cls): + from typing import get_args + + for arg in get_args(cls): + registry[arg] = func + else: + registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() @@ -925,18 +956,16 @@ def __isabstractmethod__(self): ################################################################################ -### cached_property() - computed once per instance, cached as attribute +### cached_property() - property result cached as instance attribute ################################################################################ _NOT_FOUND = object() - class cached_property: def __init__(self, func): self.func = func self.attrname = None self.__doc__ = func.__doc__ - self.lock = RLock() def __set_name__(self, owner, name): if self.attrname is None: @@ -963,19 +992,15 @@ def __get__(self, instance, owner=None): raise TypeError(msg) from None val = cache.get(self.attrname, _NOT_FOUND) if val is _NOT_FOUND: - with self.lock: - # check if another thread filled cache while we awaited lock - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - val = self.func(instance) - try: - cache[self.attrname] = val - except TypeError: - msg = ( - f"The '__dict__' attribute on {type(instance).__name__!r} instance " - f"does not support item assignment for caching {self.attrname!r} property." - ) - raise TypeError(msg) from None + val = self.func(instance) + try: + cache[self.attrname] = val + except TypeError: + msg = ( + f"The '__dict__' attribute on {type(instance).__name__!r} instance " + f"does not support item assignment for caching {self.attrname!r} property." + ) + raise TypeError(msg) from None return val __class_getitem__ = classmethod(GenericAlias) From 802b504489173df4b6dd19a0aaf89b4c228f4cf2 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 4 Oct 2023 01:41:24 +0900 Subject: [PATCH 09/19] manual type annotation --- Lib/importlib/resources/_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/importlib/resources/_common.py b/Lib/importlib/resources/_common.py index a390253534..b402e05116 100644 --- a/Lib/importlib/resources/_common.py +++ b/Lib/importlib/resources/_common.py @@ -77,12 +77,12 @@ def resolve(cand: Optional[Anchor]) -> types.ModuleType: return cast(types.ModuleType, cand) -@resolve.register +@resolve.register(str) # TODO: RUSTPYTHON; manual type annotation def _(cand: str) -> types.ModuleType: return importlib.import_module(cand) -@resolve.register +@resolve.register(type(None)) # TODO: RUSTPYTHON; manual type annotation def _(cand: None) -> types.ModuleType: return resolve(_infer_caller().f_globals['__name__']) From cfda063fb4bc1aecb0a726893928912eab01b5c7 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 12:14:58 -0700 Subject: [PATCH 10/19] slice behavior changed in 3.12 --- extra_tests/snippets/builtin_slice.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/extra_tests/snippets/builtin_slice.py b/extra_tests/snippets/builtin_slice.py index 71ab7cbde5..57fb7e21c2 100644 --- a/extra_tests/snippets/builtin_slice.py +++ b/extra_tests/snippets/builtin_slice.py @@ -82,14 +82,15 @@ assert_raises(TypeError, lambda: slice(0) <= 3) assert_raises(TypeError, lambda: slice(0) >= 3) -assert_raises(TypeError, hash, slice(0)) -assert_raises(TypeError, hash, slice(None)) - -def dict_slice(): - d = {} - d[slice(0)] = 3 - -assert_raises(TypeError, dict_slice) +# TODO: slice is hashable in CPython 3.12 +# assert_raises(TypeError, hash, slice(0)) +# assert_raises(TypeError, hash, slice(None)) +# +# def dict_slice(): +# d = {} +# d[slice(0)] = 3 +# +# assert_raises(TypeError, dict_slice) assert slice(None ).indices(10) == (0, 10, 1) assert slice(None, None, 2).indices(10) == (0, 10, 2) From 1332a5d3450675aeb889632b68aae63509bad3f1 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 12:32:54 -0700 Subject: [PATCH 11/19] empty unittest.main returns non-zero --- extra_tests/snippets/syntax_async.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extra_tests/snippets/syntax_async.py b/extra_tests/snippets/syntax_async.py index 011182cce7..953669c2c4 100644 --- a/extra_tests/snippets/syntax_async.py +++ b/extra_tests/snippets/syntax_async.py @@ -128,5 +128,5 @@ async def foo(): foo().send(None) -if __name__ == "__main__": - unittest.main() + if __name__ == "__main__": + unittest.main() From eb5484e21b4eff272d1f47fd6d5ea7917bbc795d Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 17:32:51 -0700 Subject: [PATCH 12/19] test_decimal from CPython 3.12 --- Lib/test/test_decimal.py | 435 +++++++++++++++++++++++++++------------ 1 file changed, 301 insertions(+), 134 deletions(-) diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 47e10bf2a6..ab743aa7a3 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -20,7 +20,7 @@ This test module can be called from command line with one parameter (Arithmetic or Behaviour) to test each part, or without parameter to test both parts. If -you're working through IDLE, you can import this test module and call test_main() +you're working through IDLE, you can import this test module and call test() with the corresponding argument. """ @@ -32,13 +32,14 @@ import unittest import numbers import locale -from test.support import (run_unittest, run_doctest, is_resource_enabled, +from test.support import (is_resource_enabled, requires_IEEE_754, requires_docstrings, requires_legacy_unicode_capi, check_sanitizer) from test.support import (TestFailed, run_with_locale, cpython_only, - darwin_malloc_err_warning) + darwin_malloc_err_warning, is_emscripten) from test.support.import_helper import import_fresh_module +from test.support import threading_helper from test.support import warnings_helper import random import inspect @@ -61,6 +62,7 @@ fractions = {C:cfractions, P:pfractions} sys.modules['decimal'] = orig_sys_decimal +requires_cdecimal = unittest.skipUnless(C, "test requires C version") # Useful Test Constant Signals = { @@ -98,7 +100,7 @@ def assert_signals(cls, context, attr, expected): ] # Tests are built around these assumed context defaults. -# test_main() restores the original context. +# test() restores the original context. ORIGINAL_CONTEXT = { C: C.getcontext().copy() if C else None, P: P.getcontext().copy() @@ -132,7 +134,7 @@ def init(m): EXTRA_FUNCTIONALITY, "test requires regular build") -class IBMTestCases(unittest.TestCase): +class IBMTestCases: """Class which tests the Decimal class against the IBM test cases.""" def setUp(self): @@ -487,14 +489,10 @@ def change_max_exponent(self, exp): def change_clamp(self, clamp): self.context.clamp = clamp -class CIBMTestCases(IBMTestCases): - decimal = C -class PyIBMTestCases(IBMTestCases): - decimal = P # The following classes test the behaviour of Decimal according to PEP 327 -class ExplicitConstructionTest(unittest.TestCase): +class ExplicitConstructionTest: '''Unit tests for Explicit Construction cases of Decimal.''' def test_explicit_empty(self): @@ -589,7 +587,7 @@ def test_explicit_from_string(self): self.assertRaises(InvalidOperation, Decimal, "1_2_\u00003") @cpython_only - @requires_legacy_unicode_capi + @requires_legacy_unicode_capi() @warnings_helper.ignore_warnings(category=DeprecationWarning) def test_from_legacy_strings(self): import _testcapi @@ -839,12 +837,13 @@ def test_unicode_digits(self): for input, expected in test_values.items(): self.assertEqual(str(Decimal(input)), expected) -class CExplicitConstructionTest(ExplicitConstructionTest): +@requires_cdecimal +class CExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase): decimal = C -class PyExplicitConstructionTest(ExplicitConstructionTest): +class PyExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase): decimal = P -class ImplicitConstructionTest(unittest.TestCase): +class ImplicitConstructionTest: '''Unit tests for Implicit Construction cases of Decimal.''' def test_implicit_from_None(self): @@ -921,12 +920,13 @@ def __ne__(self, other): self.assertEqual(eval('Decimal(10)' + sym + 'E()'), '10' + rop + 'str') -class CImplicitConstructionTest(ImplicitConstructionTest): +@requires_cdecimal +class CImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase): decimal = C -class PyImplicitConstructionTest(ImplicitConstructionTest): +class PyImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase): decimal = P -class FormatTest(unittest.TestCase): +class FormatTest: '''Unit tests for the format function.''' def test_formatting(self): Decimal = self.decimal.Decimal @@ -1073,6 +1073,57 @@ def test_formatting(self): (',e', '123456', '1.23456e+5'), (',E', '123456', '1.23456E+5'), + # negative zero: default behavior + ('.1f', '-0', '-0.0'), + ('.1f', '-.0', '-0.0'), + ('.1f', '-.01', '-0.0'), + + # negative zero: z option + ('z.1f', '0.', '0.0'), + ('z6.1f', '0.', ' 0.0'), + ('z6.1f', '-1.', ' -1.0'), + ('z.1f', '-0.', '0.0'), + ('z.1f', '.01', '0.0'), + ('z.1f', '-.01', '0.0'), + ('z.2f', '0.', '0.00'), + ('z.2f', '-0.', '0.00'), + ('z.2f', '.001', '0.00'), + ('z.2f', '-.001', '0.00'), + + ('z.1e', '0.', '0.0e+1'), + ('z.1e', '-0.', '0.0e+1'), + ('z.1E', '0.', '0.0E+1'), + ('z.1E', '-0.', '0.0E+1'), + + ('z.2e', '-0.001', '-1.00e-3'), # tests for mishandled rounding + ('z.2g', '-0.001', '-0.001'), + ('z.2%', '-0.001', '-0.10%'), + + ('zf', '-0.0000', '0.0000'), # non-normalized form is preserved + + ('z.1f', '-00000.000001', '0.0'), + ('z.1f', '-00000.', '0.0'), + ('z.1f', '-.0000000000', '0.0'), + + ('z.2f', '-00000.000001', '0.00'), + ('z.2f', '-00000.', '0.00'), + ('z.2f', '-.0000000000', '0.00'), + + ('z.1f', '.09', '0.1'), + ('z.1f', '-.09', '-0.1'), + + (' z.0f', '-0.', ' 0'), + ('+z.0f', '-0.', '+0'), + ('-z.0f', '-0.', '0'), + (' z.0f', '-1.', '-1'), + ('+z.0f', '-1.', '-1'), + ('-z.0f', '-1.', '-1'), + + ('z>6.1f', '-0.', 'zz-0.0'), + ('z>z6.1f', '-0.', 'zzz0.0'), + ('x>z6.1f', '-0.', 'xxx0.0'), + ('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char + # issue 6850 ('a=-7.0', '0.12345', 'aaaa0.1'), @@ -1087,6 +1138,15 @@ def test_formatting(self): # bytes format argument self.assertRaises(TypeError, Decimal(1).__format__, b'-020') + def test_negative_zero_format_directed_rounding(self): + with self.decimal.localcontext() as ctx: + ctx.rounding = ROUND_CEILING + self.assertEqual(format(self.decimal.Decimal('-0.001'), 'z.2f'), + '0.00') + + def test_negative_zero_bad_format(self): + self.assertRaises(ValueError, format, self.decimal.Decimal('1.23'), 'fz') + def test_n_format(self): Decimal = self.decimal.Decimal @@ -1205,12 +1265,13 @@ def __init__(self, a): a = A.from_float(42) self.assertEqual(self.decimal.Decimal, a.a_type) -class CFormatTest(FormatTest): +@requires_cdecimal +class CFormatTest(FormatTest, unittest.TestCase): decimal = C -class PyFormatTest(FormatTest): +class PyFormatTest(FormatTest, unittest.TestCase): decimal = P -class ArithmeticOperatorsTest(unittest.TestCase): +class ArithmeticOperatorsTest: '''Unit tests for all arithmetic operators, binary and unary.''' def test_addition(self): @@ -1466,14 +1527,17 @@ def test_nan_comparisons(self): equality_ops = operator.eq, operator.ne # results when InvalidOperation is not trapped - for x, y in qnan_pairs + snan_pairs: - for op in order_ops + equality_ops: - got = op(x, y) - expected = True if op is operator.ne else False - self.assertIs(expected, got, - "expected {0!r} for operator.{1}({2!r}, {3!r}); " - "got {4!r}".format( - expected, op.__name__, x, y, got)) + with localcontext() as ctx: + ctx.traps[InvalidOperation] = 0 + + for x, y in qnan_pairs + snan_pairs: + for op in order_ops + equality_ops: + got = op(x, y) + expected = True if op is operator.ne else False + self.assertIs(expected, got, + "expected {0!r} for operator.{1}({2!r}, {3!r}); " + "got {4!r}".format( + expected, op.__name__, x, y, got)) # repeat the above, but this time trap the InvalidOperation with localcontext() as ctx: @@ -1505,9 +1569,10 @@ def test_copy_sign(self): self.assertEqual(Decimal(1).copy_sign(-2), d) self.assertRaises(TypeError, Decimal(1).copy_sign, '-2') -class CArithmeticOperatorsTest(ArithmeticOperatorsTest): +@requires_cdecimal +class CArithmeticOperatorsTest(ArithmeticOperatorsTest, unittest.TestCase): decimal = C -class PyArithmeticOperatorsTest(ArithmeticOperatorsTest): +class PyArithmeticOperatorsTest(ArithmeticOperatorsTest, unittest.TestCase): decimal = P # The following are two functions used to test threading in the next class @@ -1595,7 +1660,9 @@ def thfunc2(cls): for sig in Overflow, Underflow, DivisionByZero, InvalidOperation: cls.assertFalse(thiscontext.flags[sig]) -class ThreadingTest(unittest.TestCase): + +@threading_helper.requires_working_threading() +class ThreadingTest: '''Unit tests for thread local contexts in Decimal.''' # Take care executing this test from IDLE, there's an issue in threading @@ -1640,13 +1707,14 @@ def test_threading(self): DefaultContext.Emin = save_emin -class CThreadingTest(ThreadingTest): +@requires_cdecimal +class CThreadingTest(ThreadingTest, unittest.TestCase): decimal = C -class PyThreadingTest(ThreadingTest): +class PyThreadingTest(ThreadingTest, unittest.TestCase): decimal = P -class UsabilityTest(unittest.TestCase): +class UsabilityTest: '''Unit tests for Usability cases of Decimal.''' def test_comparison_operators(self): @@ -2466,12 +2534,22 @@ def test_conversions_from_int(self): self.assertEqual(Decimal(-12).fma(45, Decimal(67)), Decimal(-12).fma(Decimal(45), Decimal(67))) -class CUsabilityTest(UsabilityTest): +@requires_cdecimal +class CUsabilityTest(UsabilityTest, unittest.TestCase): decimal = C -class PyUsabilityTest(UsabilityTest): +class PyUsabilityTest(UsabilityTest, unittest.TestCase): decimal = P -class PythonAPItests(unittest.TestCase): + def setUp(self): + super().setUp() + self._previous_int_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(7000) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_int_limit) + super().tearDown() + +class PythonAPItests: def test_abc(self): Decimal = self.decimal.Decimal @@ -2549,6 +2627,13 @@ def test_int(self): self.assertRaises(OverflowError, int, Decimal('inf')) self.assertRaises(OverflowError, int, Decimal('-inf')) + @cpython_only + def test_small_ints(self): + Decimal = self.decimal.Decimal + # bpo-46361 + for x in range(-5, 257): + self.assertIs(int(Decimal(x)), x) + def test_trunc(self): Decimal = self.decimal.Decimal @@ -2815,12 +2900,13 @@ def test_exception_hierarchy(self): self.assertTrue(issubclass(decimal.DivisionUndefined, ZeroDivisionError)) self.assertTrue(issubclass(decimal.InvalidContext, InvalidOperation)) -class CPythonAPItests(PythonAPItests): +@requires_cdecimal +class CPythonAPItests(PythonAPItests, unittest.TestCase): decimal = C -class PyPythonAPItests(PythonAPItests): +class PyPythonAPItests(PythonAPItests, unittest.TestCase): decimal = P -class ContextAPItests(unittest.TestCase): +class ContextAPItests: def test_none_args(self): Context = self.decimal.Context @@ -2843,7 +2929,7 @@ def test_none_args(self): Overflow]) @cpython_only - @requires_legacy_unicode_capi + @requires_legacy_unicode_capi() @warnings_helper.ignore_warnings(category=DeprecationWarning) def test_from_legacy_strings(self): import _testcapi @@ -3566,12 +3652,13 @@ def test_to_integral_value(self): self.assertRaises(TypeError, c.to_integral_value, '10') self.assertRaises(TypeError, c.to_integral_value, 10, 'x') -class CContextAPItests(ContextAPItests): +@requires_cdecimal +class CContextAPItests(ContextAPItests, unittest.TestCase): decimal = C -class PyContextAPItests(ContextAPItests): +class PyContextAPItests(ContextAPItests, unittest.TestCase): decimal = P -class ContextWithStatement(unittest.TestCase): +class ContextWithStatement: # Can't do these as docstrings until Python 2.6 # as doctest can't handle __future__ statements @@ -3605,6 +3692,44 @@ def test_localcontextarg(self): self.assertIsNot(new_ctx, set_ctx, 'did not copy the context') self.assertIs(set_ctx, enter_ctx, '__enter__ returned wrong context') + def test_localcontext_kwargs(self): + with self.decimal.localcontext( + prec=10, rounding=ROUND_HALF_DOWN, + Emin=-20, Emax=20, capitals=0, + clamp=1 + ) as ctx: + self.assertEqual(ctx.prec, 10) + self.assertEqual(ctx.rounding, self.decimal.ROUND_HALF_DOWN) + self.assertEqual(ctx.Emin, -20) + self.assertEqual(ctx.Emax, 20) + self.assertEqual(ctx.capitals, 0) + self.assertEqual(ctx.clamp, 1) + + self.assertRaises(TypeError, self.decimal.localcontext, precision=10) + + self.assertRaises(ValueError, self.decimal.localcontext, Emin=1) + self.assertRaises(ValueError, self.decimal.localcontext, Emax=-1) + self.assertRaises(ValueError, self.decimal.localcontext, capitals=2) + self.assertRaises(ValueError, self.decimal.localcontext, clamp=2) + + self.assertRaises(TypeError, self.decimal.localcontext, rounding="") + self.assertRaises(TypeError, self.decimal.localcontext, rounding=1) + + self.assertRaises(TypeError, self.decimal.localcontext, flags="") + self.assertRaises(TypeError, self.decimal.localcontext, traps="") + self.assertRaises(TypeError, self.decimal.localcontext, Emin="") + self.assertRaises(TypeError, self.decimal.localcontext, Emax="") + + def test_local_context_kwargs_does_not_overwrite_existing_argument(self): + ctx = self.decimal.getcontext() + orig_prec = ctx.prec + with self.decimal.localcontext(prec=10) as ctx2: + self.assertEqual(ctx2.prec, 10) + self.assertEqual(ctx.prec, orig_prec) + with self.decimal.localcontext(prec=20) as ctx2: + self.assertEqual(ctx2.prec, 20) + self.assertEqual(ctx.prec, orig_prec) + def test_nested_with_statements(self): # Use a copy of the supplied context in the block Decimal = self.decimal.Decimal @@ -3697,12 +3822,13 @@ def test_with_statements_gc3(self): self.assertEqual(c4.prec, 4) del c4 -class CContextWithStatement(ContextWithStatement): +@requires_cdecimal +class CContextWithStatement(ContextWithStatement, unittest.TestCase): decimal = C -class PyContextWithStatement(ContextWithStatement): +class PyContextWithStatement(ContextWithStatement, unittest.TestCase): decimal = P -class ContextFlags(unittest.TestCase): +class ContextFlags: def test_flags_irrelevant(self): # check that the result (numeric result + flags raised) of an @@ -3969,12 +4095,13 @@ def test_float_operation_default(self): self.assertTrue(context.traps[FloatOperation]) self.assertTrue(context.traps[Inexact]) -class CContextFlags(ContextFlags): +@requires_cdecimal +class CContextFlags(ContextFlags, unittest.TestCase): decimal = C -class PyContextFlags(ContextFlags): +class PyContextFlags(ContextFlags, unittest.TestCase): decimal = P -class SpecialContexts(unittest.TestCase): +class SpecialContexts: """Test the context templates.""" def test_context_templates(self): @@ -4054,12 +4181,13 @@ def test_default_context(self): if ex: raise ex -class CSpecialContexts(SpecialContexts): +@requires_cdecimal +class CSpecialContexts(SpecialContexts, unittest.TestCase): decimal = C -class PySpecialContexts(SpecialContexts): +class PySpecialContexts(SpecialContexts, unittest.TestCase): decimal = P -class ContextInputValidation(unittest.TestCase): +class ContextInputValidation: def test_invalid_context(self): Context = self.decimal.Context @@ -4121,12 +4249,13 @@ def test_invalid_context(self): self.assertRaises(TypeError, Context, flags=(0,1)) self.assertRaises(TypeError, Context, traps=(1,0)) -class CContextInputValidation(ContextInputValidation): +@requires_cdecimal +class CContextInputValidation(ContextInputValidation, unittest.TestCase): decimal = C -class PyContextInputValidation(ContextInputValidation): +class PyContextInputValidation(ContextInputValidation, unittest.TestCase): decimal = P -class ContextSubclassing(unittest.TestCase): +class ContextSubclassing: def test_context_subclassing(self): decimal = self.decimal @@ -4235,12 +4364,14 @@ def __init__(self, prec=None, rounding=None, Emin=None, Emax=None, for signal in OrderedSignals[decimal]: self.assertFalse(c.traps[signal]) -class CContextSubclassing(ContextSubclassing): +@requires_cdecimal +class CContextSubclassing(ContextSubclassing, unittest.TestCase): decimal = C -class PyContextSubclassing(ContextSubclassing): +class PyContextSubclassing(ContextSubclassing, unittest.TestCase): decimal = P @skip_if_extra_functionality +@requires_cdecimal class CheckAttributes(unittest.TestCase): def test_module_attributes(self): @@ -4270,7 +4401,7 @@ def test_decimal_attributes(self): y = [s for s in dir(C.Decimal(9)) if '__' in s or not s.startswith('_')] self.assertEqual(set(x) - set(y), set()) -class Coverage(unittest.TestCase): +class Coverage: def test_adjusted(self): Decimal = self.decimal.Decimal @@ -4527,11 +4658,21 @@ def test_copy(self): y = c.copy_sign(x, 1) self.assertEqual(y, -x) -class CCoverage(Coverage): +@requires_cdecimal +class CCoverage(Coverage, unittest.TestCase): decimal = C -class PyCoverage(Coverage): +class PyCoverage(Coverage, unittest.TestCase): decimal = P + def setUp(self): + super().setUp() + self._previous_int_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(7000) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_int_limit) + super().tearDown() + class PyFunctionality(unittest.TestCase): """Extra functionality in decimal.py""" @@ -4773,6 +4914,7 @@ def test_constants(self): self.assertEqual(C.DecTraps, C.DecErrors|C.DecOverflow|C.DecUnderflow) +@requires_cdecimal class CWhitebox(unittest.TestCase): """Whitebox testing for _decimal""" @@ -5426,6 +5568,7 @@ def test_from_tuple(self): with localcontext() as c: + c.prec = 9 c.traps[InvalidOperation] = True c.traps[Overflow] = True c.traps[Underflow] = True @@ -5510,6 +5653,7 @@ def __abs__(self): # Issue 41540: @unittest.skipIf(sys.platform.startswith("aix"), "AIX: default ulimit: test is flaky because of extreme over-allocation") + @unittest.skipIf(is_emscripten, "Test is unstable on Emscripten") @unittest.skipIf(check_sanitizer(address=True, memory=True), "ASAN/MSAN sanitizer defaults to crashing " "instead of returning NULL for malloc failure.") @@ -5548,8 +5692,38 @@ def test_maxcontext_exact_arith(self): self.assertEqual(Decimal(400) ** -1, Decimal('0.0025')) + def test_c_signaldict_segfault(self): + # See gh-106263 for details. + SignalDict = type(C.Context().flags) + sd = SignalDict() + err_msg = "invalid signal dict" + + with self.assertRaisesRegex(ValueError, err_msg): + len(sd) + + with self.assertRaisesRegex(ValueError, err_msg): + iter(sd) + + with self.assertRaisesRegex(ValueError, err_msg): + repr(sd) + + with self.assertRaisesRegex(ValueError, err_msg): + sd[C.InvalidOperation] = True + + with self.assertRaisesRegex(ValueError, err_msg): + sd[C.InvalidOperation] + + with self.assertRaisesRegex(ValueError, err_msg): + sd == C.Context().flags + + with self.assertRaisesRegex(ValueError, err_msg): + C.Context().flags == sd + + with self.assertRaisesRegex(ValueError, err_msg): + sd.copy() + @requires_docstrings -@unittest.skipUnless(C, "test requires C version") +@requires_cdecimal class SignatureTest(unittest.TestCase): """Function signatures""" @@ -5685,52 +5859,10 @@ def doit(ty): doit('Context') -all_tests = [ - CExplicitConstructionTest, PyExplicitConstructionTest, - CImplicitConstructionTest, PyImplicitConstructionTest, - CFormatTest, PyFormatTest, - CArithmeticOperatorsTest, PyArithmeticOperatorsTest, - CThreadingTest, PyThreadingTest, - CUsabilityTest, PyUsabilityTest, - CPythonAPItests, PyPythonAPItests, - CContextAPItests, PyContextAPItests, - CContextWithStatement, PyContextWithStatement, - CContextFlags, PyContextFlags, - CSpecialContexts, PySpecialContexts, - CContextInputValidation, PyContextInputValidation, - CContextSubclassing, PyContextSubclassing, - CCoverage, PyCoverage, - CFunctionality, PyFunctionality, - CWhitebox, PyWhitebox, - CIBMTestCases, PyIBMTestCases, -] - -# Delete C tests if _decimal.so is not present. -if not C: - all_tests = all_tests[1::2] -else: - all_tests.insert(0, CheckAttributes) - all_tests.insert(1, SignatureTest) - - -def test_main(arith=None, verbose=None, todo_tests=None, debug=None): - """ Execute the tests. - - Runs all arithmetic tests if arith is True or if the "decimal" resource - is enabled in regrtest.py - """ - - init(C) - init(P) - global TEST_ALL, DEBUG - TEST_ALL = arith if arith is not None else is_resource_enabled('decimal') - DEBUG = debug - - if todo_tests is None: - test_classes = all_tests - else: - test_classes = [CIBMTestCases, PyIBMTestCases] - +def load_tests(loader, tests, pattern): + if TODO_TESTS is not None: + # Run only Arithmetic tests + tests = loader.suiteClass() # Dynamically build custom test definition for each file in the test # directory and add the definitions to the DecimalTest class. This # procedure insures that new files do not get skipped. @@ -5738,34 +5870,69 @@ def test_main(arith=None, verbose=None, todo_tests=None, debug=None): if '.decTest' not in filename or filename.startswith("."): continue head, tail = filename.split('.') - if todo_tests is not None and head not in todo_tests: + if TODO_TESTS is not None and head not in TODO_TESTS: continue tester = lambda self, f=filename: self.eval_file(directory + f) - setattr(CIBMTestCases, 'test_' + head, tester) - setattr(PyIBMTestCases, 'test_' + head, tester) + setattr(IBMTestCases, 'test_' + head, tester) del filename, head, tail, tester + for prefix, mod in ('C', C), ('Py', P): + if not mod: + continue + test_class = type(prefix + 'IBMTestCases', + (IBMTestCases, unittest.TestCase), + {'decimal': mod}) + tests.addTest(loader.loadTestsFromTestCase(test_class)) + + if TODO_TESTS is None: + from doctest import DocTestSuite, IGNORE_EXCEPTION_DETAIL + for mod in C, P: + if not mod: + continue + def setUp(slf, mod=mod): + sys.modules['decimal'] = mod + def tearDown(slf): + sys.modules['decimal'] = orig_sys_decimal + optionflags = IGNORE_EXCEPTION_DETAIL if mod is C else 0 + sys.modules['decimal'] = mod + tests.addTest(DocTestSuite(mod, setUp=setUp, tearDown=tearDown, + optionflags=optionflags)) + sys.modules['decimal'] = orig_sys_decimal + return tests + +def setUpModule(): + init(C) + init(P) + global TEST_ALL + TEST_ALL = ARITH if ARITH is not None else is_resource_enabled('decimal') + +def tearDownModule(): + if C: C.setcontext(ORIGINAL_CONTEXT[C]) + P.setcontext(ORIGINAL_CONTEXT[P]) + if not C: + warnings.warn('C tests skipped: no module named _decimal.', + UserWarning) + if not orig_sys_decimal is sys.modules['decimal']: + raise TestFailed("Internal error: unbalanced number of changes to " + "sys.modules['decimal'].") + + +ARITH = None +TEST_ALL = True +TODO_TESTS = None +DEBUG = False + +def test(arith=None, verbose=None, todo_tests=None, debug=None): + """ Execute the tests. + Runs all arithmetic tests if arith is True or if the "decimal" resource + is enabled in regrtest.py + """ - try: - run_unittest(*test_classes) - if todo_tests is None: - from doctest import IGNORE_EXCEPTION_DETAIL - savedecimal = sys.modules['decimal'] - if C: - sys.modules['decimal'] = C - run_doctest(C, verbose, optionflags=IGNORE_EXCEPTION_DETAIL) - sys.modules['decimal'] = P - run_doctest(P, verbose) - sys.modules['decimal'] = savedecimal - finally: - if C: C.setcontext(ORIGINAL_CONTEXT[C]) - P.setcontext(ORIGINAL_CONTEXT[P]) - if not C: - warnings.warn('C tests skipped: no module named _decimal.', - UserWarning) - if not orig_sys_decimal is sys.modules['decimal']: - raise TestFailed("Internal error: unbalanced number of changes to " - "sys.modules['decimal'].") + global ARITH, TODO_TESTS, DEBUG + ARITH = arith + TODO_TESTS = todo_tests + DEBUG = debug + unittest.main(__name__, verbosity=2 if verbose else 1, exit=False, argv=[__name__]) if __name__ == '__main__': @@ -5776,8 +5943,8 @@ def test_main(arith=None, verbose=None, todo_tests=None, debug=None): (opt, args) = p.parse_args() if opt.skip: - test_main(arith=False, verbose=True) + test(arith=False, verbose=True) elif args: - test_main(arith=True, verbose=True, todo_tests=args, debug=opt.debug) + test(arith=True, verbose=True, todo_tests=args, debug=opt.debug) else: - test_main(arith=True, verbose=True) + test(arith=True, verbose=True) From 940d68beb1bc2ab6b4c44b4144618c124b2e4259 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 17:43:12 -0700 Subject: [PATCH 13/19] Mark failing tests --- Lib/test/test_decimal.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index ab743aa7a3..472397e8df 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -928,6 +928,8 @@ class PyImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase): class FormatTest: '''Unit tests for the format function.''' + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_formatting(self): Decimal = self.decimal.Decimal @@ -1138,6 +1140,8 @@ def test_formatting(self): # bytes format argument self.assertRaises(TypeError, Decimal(1).__format__, b'-020') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_negative_zero_format_directed_rounding(self): with self.decimal.localcontext() as ctx: ctx.rounding = ROUND_CEILING @@ -3692,6 +3696,8 @@ def test_localcontextarg(self): self.assertIsNot(new_ctx, set_ctx, 'did not copy the context') self.assertIs(set_ctx, enter_ctx, '__enter__ returned wrong context') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_localcontext_kwargs(self): with self.decimal.localcontext( prec=10, rounding=ROUND_HALF_DOWN, @@ -3720,6 +3726,8 @@ def test_localcontext_kwargs(self): self.assertRaises(TypeError, self.decimal.localcontext, Emin="") self.assertRaises(TypeError, self.decimal.localcontext, Emax="") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_local_context_kwargs_does_not_overwrite_existing_argument(self): ctx = self.decimal.getcontext() orig_prec = ctx.prec From a17fd97cac8cc48de85820caa1d8f93d02eb3143 Mon Sep 17 00:00:00 2001 From: CPython developers <> Date: Fri, 20 Oct 2023 17:46:20 -0700 Subject: [PATCH 14/19] Update test_unicode from CPython 3.12 --- Lib/test/test_unicode.py | 878 ++++++++++++--------------------------- 1 file changed, 269 insertions(+), 609 deletions(-) diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 071a2a06c1..17c9f01cd8 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -9,17 +9,22 @@ import codecs import itertools import operator +import pickle import struct import sys import textwrap import unicodedata import unittest import warnings -from test.support import import_helper from test.support import warnings_helper from test import support, string_tests from test.support.script_helper import assert_python_failure +try: + import _testcapi +except ImportError: + _testcapi = None + # Error handling (bad decoder return) def search_function(encoding): def decode1(input, errors="strict"): @@ -89,88 +94,85 @@ def test_literals(self): self.assertNotEqual(r"\u0020", " ") def test_ascii(self): - if not sys.platform.startswith('java'): - # Test basic sanity of repr() - self.assertEqual(ascii('abc'), "'abc'") - self.assertEqual(ascii('ab\\c'), "'ab\\\\c'") - self.assertEqual(ascii('ab\\'), "'ab\\\\'") - self.assertEqual(ascii('\\c'), "'\\\\c'") - self.assertEqual(ascii('\\'), "'\\\\'") - self.assertEqual(ascii('\n'), "'\\n'") - self.assertEqual(ascii('\r'), "'\\r'") - self.assertEqual(ascii('\t'), "'\\t'") - self.assertEqual(ascii('\b'), "'\\x08'") - self.assertEqual(ascii("'\""), """'\\'"'""") - self.assertEqual(ascii("'\""), """'\\'"'""") - self.assertEqual(ascii("'"), '''"'"''') - self.assertEqual(ascii('"'), """'"'""") - latin1repr = ( - "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" - "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" - "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" - "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" - "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" - "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" - "\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9" - "\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7" - "\\xb8\\xb9\\xba\\xbb\\xbc\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5" - "\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3" - "\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1" - "\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef" - "\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd" - "\\xfe\\xff'") - testrepr = ascii(''.join(map(chr, range(256)))) - self.assertEqual(testrepr, latin1repr) - # Test ascii works on wide unicode escapes without overflow. - self.assertEqual(ascii("\U00010000" * 39 + "\uffff" * 4096), - ascii("\U00010000" * 39 + "\uffff" * 4096)) - - class WrongRepr: - def __repr__(self): - return b'byte-repr' - self.assertRaises(TypeError, ascii, WrongRepr()) + self.assertEqual(ascii('abc'), "'abc'") + self.assertEqual(ascii('ab\\c'), "'ab\\\\c'") + self.assertEqual(ascii('ab\\'), "'ab\\\\'") + self.assertEqual(ascii('\\c'), "'\\\\c'") + self.assertEqual(ascii('\\'), "'\\\\'") + self.assertEqual(ascii('\n'), "'\\n'") + self.assertEqual(ascii('\r'), "'\\r'") + self.assertEqual(ascii('\t'), "'\\t'") + self.assertEqual(ascii('\b'), "'\\x08'") + self.assertEqual(ascii("'\""), """'\\'"'""") + self.assertEqual(ascii("'\""), """'\\'"'""") + self.assertEqual(ascii("'"), '''"'"''') + self.assertEqual(ascii('"'), """'"'""") + latin1repr = ( + "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" + "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" + "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" + "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" + "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" + "\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9" + "\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7" + "\\xb8\\xb9\\xba\\xbb\\xbc\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5" + "\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3" + "\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1" + "\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef" + "\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd" + "\\xfe\\xff'") + testrepr = ascii(''.join(map(chr, range(256)))) + self.assertEqual(testrepr, latin1repr) + # Test ascii works on wide unicode escapes without overflow. + self.assertEqual(ascii("\U00010000" * 39 + "\uffff" * 4096), + ascii("\U00010000" * 39 + "\uffff" * 4096)) + + class WrongRepr: + def __repr__(self): + return b'byte-repr' + self.assertRaises(TypeError, ascii, WrongRepr()) def test_repr(self): - if not sys.platform.startswith('java'): - # Test basic sanity of repr() - self.assertEqual(repr('abc'), "'abc'") - self.assertEqual(repr('ab\\c'), "'ab\\\\c'") - self.assertEqual(repr('ab\\'), "'ab\\\\'") - self.assertEqual(repr('\\c'), "'\\\\c'") - self.assertEqual(repr('\\'), "'\\\\'") - self.assertEqual(repr('\n'), "'\\n'") - self.assertEqual(repr('\r'), "'\\r'") - self.assertEqual(repr('\t'), "'\\t'") - self.assertEqual(repr('\b'), "'\\x08'") - self.assertEqual(repr("'\""), """'\\'"'""") - self.assertEqual(repr("'\""), """'\\'"'""") - self.assertEqual(repr("'"), '''"'"''') - self.assertEqual(repr('"'), """'"'""") - latin1repr = ( - "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" - "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" - "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" - "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" - "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" - "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" - "\\x9c\\x9d\\x9e\\x9f\\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9" - "\xaa\xab\xac\\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7" - "\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5" - "\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3" - "\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1" - "\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef" - "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd" - "\xfe\xff'") - testrepr = repr(''.join(map(chr, range(256)))) - self.assertEqual(testrepr, latin1repr) - # Test repr works on wide unicode escapes without overflow. - self.assertEqual(repr("\U00010000" * 39 + "\uffff" * 4096), - repr("\U00010000" * 39 + "\uffff" * 4096)) - - class WrongRepr: - def __repr__(self): - return b'byte-repr' - self.assertRaises(TypeError, repr, WrongRepr()) + # Test basic sanity of repr() + self.assertEqual(repr('abc'), "'abc'") + self.assertEqual(repr('ab\\c'), "'ab\\\\c'") + self.assertEqual(repr('ab\\'), "'ab\\\\'") + self.assertEqual(repr('\\c'), "'\\\\c'") + self.assertEqual(repr('\\'), "'\\\\'") + self.assertEqual(repr('\n'), "'\\n'") + self.assertEqual(repr('\r'), "'\\r'") + self.assertEqual(repr('\t'), "'\\t'") + self.assertEqual(repr('\b'), "'\\x08'") + self.assertEqual(repr("'\""), """'\\'"'""") + self.assertEqual(repr("'\""), """'\\'"'""") + self.assertEqual(repr("'"), '''"'"''') + self.assertEqual(repr('"'), """'"'""") + latin1repr = ( + "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" + "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" + "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" + "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" + "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" + "\\x9c\\x9d\\x9e\\x9f\\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9" + "\xaa\xab\xac\\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7" + "\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5" + "\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3" + "\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1" + "\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef" + "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd" + "\xfe\xff'") + testrepr = repr(''.join(map(chr, range(256)))) + self.assertEqual(testrepr, latin1repr) + # Test repr works on wide unicode escapes without overflow. + self.assertEqual(repr("\U00010000" * 39 + "\uffff" * 4096), + repr("\U00010000" * 39 + "\uffff" * 4096)) + + class WrongRepr: + def __repr__(self): + return b'byte-repr' + self.assertRaises(TypeError, repr, WrongRepr()) def test_iterators(self): # Make sure unicode objects have an __iter__ method @@ -180,6 +182,36 @@ def test_iterators(self): self.assertEqual(next(it), "\u3333") self.assertRaises(StopIteration, next, it) + def test_iterators_invocation(self): + cases = [type(iter('abc')), type(iter('🚀'))] + for cls in cases: + with self.subTest(cls=cls): + self.assertRaises(TypeError, cls) + + def test_iteration(self): + cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"] + for case in cases: + with self.subTest(string=case): + self.assertEqual(case, "".join(iter(case))) + + def test_exhausted_iterator(self): + cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"] + for case in cases: + with self.subTest(case=case): + iterator = iter(case) + tuple(iterator) + self.assertRaises(StopIteration, next, iterator) + + def test_pickle_iterator(self): + cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"] + for case in cases: + with self.subTest(case=case): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + it = iter(case) + with self.subTest(proto=proto): + pickled = "".join(pickle.loads(pickle.dumps(it, proto))) + self.assertEqual(case, pickled) + def test_count(self): string_tests.CommonTest.test_count(self) # check mixed argument types @@ -205,6 +237,10 @@ def test_count(self): self.checkequal(0, 'a' * 10, 'count', 'a\u0102') self.checkequal(0, 'a' * 10, 'count', 'a\U00100304') self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304') + # test subclass + class MyStr(str): + pass + self.checkequal(3, MyStr('aaa'), 'count', 'a') def test_find(self): string_tests.CommonTest.test_find(self) @@ -221,6 +257,20 @@ def test_find(self): self.checkequalnofix(9, 'abcdefghiabc', 'find', 'abc', 1) self.checkequalnofix(-1, 'abcdefghiabc', 'find', 'def', 4) + # test utf-8 non-ascii char + self.checkequal(0, 'тест', 'find', 'т') + self.checkequal(3, 'тест', 'find', 'т', 1) + self.checkequal(-1, 'тест', 'find', 'т', 1, 3) + self.checkequal(-1, 'тест', 'find', 'e') # english `e` + # test utf-8 non-ascii slice + self.checkequal(1, 'тест тест', 'find', 'ес') + self.checkequal(1, 'тест тест', 'find', 'ес', 1) + self.checkequal(1, 'тест тест', 'find', 'ес', 1, 3) + self.checkequal(6, 'тест тест', 'find', 'ес', 2) + self.checkequal(-1, 'тест тест', 'find', 'ес', 6, 7) + self.checkequal(-1, 'тест тест', 'find', 'ес', 7) + self.checkequal(-1, 'тест тест', 'find', 'ec') # english `ec` + self.assertRaises(TypeError, 'hello'.find) self.assertRaises(TypeError, 'hello'.find, 42) # test mixed kinds @@ -251,6 +301,19 @@ def test_rfind(self): self.checkequalnofix(9, 'abcdefghiabc', 'rfind', 'abc') self.checkequalnofix(12, 'abcdefghiabc', 'rfind', '') self.checkequalnofix(12, 'abcdefghiabc', 'rfind', '') + # test utf-8 non-ascii char + self.checkequal(1, 'тест', 'rfind', 'е') + self.checkequal(1, 'тест', 'rfind', 'е', 1) + self.checkequal(-1, 'тест', 'rfind', 'е', 2) + self.checkequal(-1, 'тест', 'rfind', 'e') # english `e` + # test utf-8 non-ascii slice + self.checkequal(6, 'тест тест', 'rfind', 'ес') + self.checkequal(6, 'тест тест', 'rfind', 'ес', 1) + self.checkequal(1, 'тест тест', 'rfind', 'ес', 1, 3) + self.checkequal(6, 'тест тест', 'rfind', 'ес', 2) + self.checkequal(-1, 'тест тест', 'rfind', 'ес', 6, 7) + self.checkequal(-1, 'тест тест', 'rfind', 'ес', 7) + self.checkequal(-1, 'тест тест', 'rfind', 'ec') # english `ec` # test mixed kinds self.checkequal(0, 'a' + '\u0102' * 100, 'rfind', 'a') self.checkequal(0, 'a' + '\U00100304' * 100, 'rfind', 'a') @@ -407,10 +470,10 @@ def test_split(self): def test_rsplit(self): string_tests.CommonTest.test_rsplit(self) # test mixed kinds - for left, right in ('ba', '\u0101\u0100', '\U00010301\U00010300'): + for left, right in ('ba', 'юё', '\u0101\u0100', '\U00010301\U00010300'): left *= 9 right *= 9 - for delim in ('c', '\u0102', '\U00010302'): + for delim in ('c', 'ы', '\u0102', '\U00010302'): self.checkequal([left + right], left + right, 'rsplit', delim) self.checkequal([left, right], @@ -420,6 +483,10 @@ def test_rsplit(self): self.checkequal([left, right], left + delim * 2 + right, 'rsplit', delim *2) + # Check `None` as well: + self.checkequal([left + right], + left + right, 'rsplit', None) + def test_partition(self): string_tests.MixinStrUnicodeUserStringTest.test_partition(self) # test mixed kinds @@ -619,8 +686,7 @@ def test_islower(self): def test_isupper(self): super().test_isupper() - if not sys.platform.startswith('java'): - self.checkequalnofix(False, '\u1FFc', 'isupper') + self.checkequalnofix(False, '\u1FFc', 'isupper') self.assertTrue('\u2167'.isupper()) self.assertFalse('\u2177'.isupper()) # non-BMP, uppercase @@ -757,9 +823,9 @@ def test_isidentifier(self): self.assertFalse("0".isidentifier()) @support.cpython_only - @support.requires_legacy_unicode_capi + @support.requires_legacy_unicode_capi() + @unittest.skipIf(_testcapi is None, 'need _testcapi module') def test_isidentifier_legacy(self): - import _testcapi u = '𝖀𝖓𝖎𝖈𝖔𝖉𝖊' self.assertTrue(u.isidentifier()) with warnings_helper.check_warnings(): @@ -1261,6 +1327,20 @@ def __repr__(self): self.assertRaises(ValueError, ("{" + big + "}").format) self.assertRaises(ValueError, ("{[" + big + "]}").format, [0]) + # test number formatter errors: + self.assertRaises(ValueError, '{0:x}'.format, 1j) + self.assertRaises(ValueError, '{0:x}'.format, 1.0) + self.assertRaises(ValueError, '{0:X}'.format, 1j) + self.assertRaises(ValueError, '{0:X}'.format, 1.0) + self.assertRaises(ValueError, '{0:o}'.format, 1j) + self.assertRaises(ValueError, '{0:o}'.format, 1.0) + self.assertRaises(ValueError, '{0:u}'.format, 1j) + self.assertRaises(ValueError, '{0:u}'.format, 1.0) + self.assertRaises(ValueError, '{0:i}'.format, 1j) + self.assertRaises(ValueError, '{0:i}'.format, 1.0) + self.assertRaises(ValueError, '{0:d}'.format, 1j) + self.assertRaises(ValueError, '{0:d}'.format, 1.0) + # issue 6089 self.assertRaises(ValueError, "{0[0]x}".format, [None]) self.assertRaises(ValueError, "{0[0](10)}".format, [None]) @@ -1431,10 +1511,9 @@ def test_formatting(self): self.assertEqual("%s, %s, %i, %f, %5.2f" % ("abc", "abc", -1, -2, 3.5), 'abc, abc, -1, -2.000000, 3.50') self.assertEqual("%s, %s, %i, %f, %5.2f" % ("abc", "abc", -1, -2, 3.57), 'abc, abc, -1, -2.000000, 3.57') self.assertEqual("%s, %s, %i, %f, %5.2f" % ("abc", "abc", -1, -2, 1003.57), 'abc, abc, -1, -2.000000, 1003.57') - if not sys.platform.startswith('java'): - self.assertEqual("%r, %r" % (b"abc", "abc"), "b'abc', 'abc'") - self.assertEqual("%r" % ("\u1234",), "'\u1234'") - self.assertEqual("%a" % ("\u1234",), "'\\u1234'") + self.assertEqual("%r, %r" % (b"abc", "abc"), "b'abc', 'abc'") + self.assertEqual("%r" % ("\u1234",), "'\u1234'") + self.assertEqual("%a" % ("\u1234",), "'\\u1234'") self.assertEqual("%(x)s, %(y)s" % {'x':"abc", 'y':"def"}, 'abc, def') self.assertEqual("%(x)s, %(\xfc)s" % {'x':"abc", '\xfc':"def"}, 'abc, def') @@ -1503,38 +1582,60 @@ def __int__(self): self.assertEqual('%X' % letter_m, '6D') self.assertEqual('%o' % letter_m, '155') self.assertEqual('%c' % letter_m, 'm') - self.assertRaisesRegex(TypeError, '%x format: an integer is required, not float', operator.mod, '%x', 3.14), - self.assertRaisesRegex(TypeError, '%X format: an integer is required, not float', operator.mod, '%X', 2.11), - self.assertRaisesRegex(TypeError, '%o format: an integer is required, not float', operator.mod, '%o', 1.79), - self.assertRaisesRegex(TypeError, '%x format: an integer is required, not PseudoFloat', operator.mod, '%x', pi), - self.assertRaises(TypeError, operator.mod, '%c', pi), + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not float', operator.mod, '%x', 3.14) + self.assertRaisesRegex(TypeError, '%X format: an integer is required, not float', operator.mod, '%X', 2.11) + self.assertRaisesRegex(TypeError, '%o format: an integer is required, not float', operator.mod, '%o', 1.79) + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not PseudoFloat', operator.mod, '%x', pi) + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not complex', operator.mod, '%x', 3j) + self.assertRaisesRegex(TypeError, '%X format: an integer is required, not complex', operator.mod, '%X', 2j) + self.assertRaisesRegex(TypeError, '%o format: an integer is required, not complex', operator.mod, '%o', 1j) + self.assertRaisesRegex(TypeError, '%u format: a real number is required, not complex', operator.mod, '%u', 3j) + self.assertRaisesRegex(TypeError, '%i format: a real number is required, not complex', operator.mod, '%i', 2j) + self.assertRaisesRegex(TypeError, '%d format: a real number is required, not complex', operator.mod, '%d', 1j) + self.assertRaisesRegex(TypeError, '%c requires int or char', operator.mod, '%c', pi) + + class RaisingNumber: + def __int__(self): + raise RuntimeError('int') # should not be `TypeError` + def __index__(self): + raise RuntimeError('index') # should not be `TypeError` + rn = RaisingNumber() + self.assertRaisesRegex(RuntimeError, 'int', operator.mod, '%d', rn) + self.assertRaisesRegex(RuntimeError, 'int', operator.mod, '%i', rn) + self.assertRaisesRegex(RuntimeError, 'int', operator.mod, '%u', rn) + self.assertRaisesRegex(RuntimeError, 'index', operator.mod, '%x', rn) + self.assertRaisesRegex(RuntimeError, 'index', operator.mod, '%X', rn) + self.assertRaisesRegex(RuntimeError, 'index', operator.mod, '%o', rn) - # TODO: RUSTPYTHON, AssertionError: '...15...' != '...Int.IDES...' - @unittest.expectedFailure def test_formatting_with_enum(self): # issue18780 import enum class Float(float, enum.Enum): + # a mixed-in type will use the name for %s etc. PI = 3.1415926 class Int(enum.IntEnum): + # IntEnum uses the value and not the name for %s etc. IDES = 15 - class Str(str, enum.Enum): + class Str(enum.StrEnum): + # StrEnum uses the value and not the name for %s etc. ABC = 'abc' # Testing Unicode formatting strings... self.assertEqual("%s, %s" % (Str.ABC, Str.ABC), - 'Str.ABC, Str.ABC') + 'abc, abc') self.assertEqual("%s, %s, %d, %i, %u, %f, %5.2f" % (Str.ABC, Str.ABC, Int.IDES, Int.IDES, Int.IDES, Float.PI, Float.PI), - 'Str.ABC, Str.ABC, 15, 15, 15, 3.141593, 3.14') + 'abc, abc, 15, 15, 15, 3.141593, 3.14') # formatting jobs delegated from the string implementation: self.assertEqual('...%(foo)s...' % {'foo':Str.ABC}, - '...Str.ABC...') + '...abc...') + self.assertEqual('...%(foo)r...' % {'foo':Int.IDES}, + '......') self.assertEqual('...%(foo)s...' % {'foo':Int.IDES}, - '...Int.IDES...') + '...15...') self.assertEqual('...%(foo)i...' % {'foo':Int.IDES}, '...15...') self.assertEqual('...%(foo)d...' % {'foo':Int.IDES}, @@ -1559,9 +1660,9 @@ def __rmod__(self, other): "Success, self.__rmod__('lhs %% %r') was called") @support.cpython_only + @unittest.skipIf(_testcapi is None, 'need _testcapi module') def test_formatting_huge_precision_c_limits(self): - from _testcapi import INT_MAX - format_string = "%.{}f".format(INT_MAX + 1) + format_string = "%.{}f".format(_testcapi.INT_MAX + 1) with self.assertRaises(ValueError): result = format_string % 2.34 @@ -1627,29 +1728,27 @@ def __str__(self): # unicode(obj, encoding, error) tests (this maps to # PyUnicode_FromEncodedObject() at C level) - if not sys.platform.startswith('java'): - self.assertRaises( - TypeError, - str, - 'decoding unicode is not supported', - 'utf-8', - 'strict' - ) + self.assertRaises( + TypeError, + str, + 'decoding unicode is not supported', + 'utf-8', + 'strict' + ) self.assertEqual( str(b'strings are decoded to unicode', 'utf-8', 'strict'), 'strings are decoded to unicode' ) - if not sys.platform.startswith('java'): - self.assertEqual( - str( - memoryview(b'character buffers are decoded to unicode'), - 'utf-8', - 'strict' - ), - 'character buffers are decoded to unicode' - ) + self.assertEqual( + str( + memoryview(b'character buffers are decoded to unicode'), + 'utf-8', + 'strict' + ), + 'character buffers are decoded to unicode' + ) self.assertRaises(TypeError, str, 42, 42, 42) @@ -2347,12 +2446,7 @@ class s1: def __repr__(self): return '\\n' - class s2: - def __repr__(self): - return '\\n' - self.assertEqual(repr(s1()), '\\n') - self.assertEqual(repr(s2()), '\\n') def test_printable_repr(self): self.assertEqual(repr('\U00010000'), "'%c'" % (0x10000,)) # printable @@ -2374,20 +2468,19 @@ def test_expandtabs_optimization(self): @unittest.skip("TODO: RUSTPYTHON, aborted: memory allocation of 9223372036854775759 bytes failed") def test_raiseMemError(self): - if struct.calcsize('P') == 8: - # 64 bits pointers - ascii_struct_size = 48 - compact_struct_size = 72 - else: - # 32 bits pointers - ascii_struct_size = 24 - compact_struct_size = 36 + asciifields = "nnb" + compactfields = asciifields + "nP" + ascii_struct_size = support.calcobjsize(asciifields) + compact_struct_size = support.calcobjsize(compactfields) for char in ('a', '\xe9', '\u20ac', '\U0010ffff'): code = ord(char) - if code < 0x100: + if code < 0x80: char_size = 1 # sizeof(Py_UCS1) struct_size = ascii_struct_size + elif code < 0x100: + char_size = 1 # sizeof(Py_UCS1) + struct_size = compact_struct_size elif code < 0x10000: char_size = 2 # sizeof(Py_UCS2) struct_size = compact_struct_size @@ -2399,8 +2492,18 @@ def test_raiseMemError(self): # be allocatable, given enough memory. maxlen = ((sys.maxsize - struct_size) // char_size) alloc = lambda: char * maxlen - self.assertRaises(MemoryError, alloc) - self.assertRaises(MemoryError, alloc) + with self.subTest( + char=char, + struct_size=struct_size, + char_size=char_size + ): + # self-check + self.assertEqual( + sys.getsizeof(char * 42), + struct_size + (char_size * (42 + 1)) + ) + self.assertRaises(MemoryError, alloc) + self.assertRaises(MemoryError, alloc) def test_format_subclass(self): class S(str): @@ -2430,22 +2533,22 @@ def test_getnewargs(self): self.assertEqual(len(args), 1) @support.cpython_only - @support.requires_legacy_unicode_capi + @support.requires_legacy_unicode_capi() + @unittest.skipIf(_testcapi is None, 'need _testcapi module') def test_resize(self): - from _testcapi import getargs_u for length in range(1, 100, 7): # generate a fresh string (refcount=1) text = 'a' * length + 'b' # fill wstr internal field with self.assertWarns(DeprecationWarning): - abc = getargs_u(text) + abc = _testcapi.getargs_u(text) self.assertEqual(abc, text) # resize text: wstr field must be cleared and then recomputed text += 'c' with self.assertWarns(DeprecationWarning): - abcdef = getargs_u(text) + abcdef = _testcapi.getargs_u(text) self.assertNotEqual(abc, abcdef) self.assertEqual(abcdef, text) @@ -2592,473 +2695,6 @@ def test_check_encoding_errors(self): self.assertEqual(proc.rc, 10, proc) -class CAPITest(unittest.TestCase): - - # Test PyUnicode_FromFormat() - def test_from_format(self): - import_helper.import_module('ctypes') - from ctypes import ( - c_char_p, - pythonapi, py_object, sizeof, - c_int, c_long, c_longlong, c_ssize_t, - c_uint, c_ulong, c_ulonglong, c_size_t, c_void_p) - name = "PyUnicode_FromFormat" - _PyUnicode_FromFormat = getattr(pythonapi, name) - _PyUnicode_FromFormat.argtypes = (c_char_p,) - _PyUnicode_FromFormat.restype = py_object - - def PyUnicode_FromFormat(format, *args): - cargs = tuple( - py_object(arg) if isinstance(arg, str) else arg - for arg in args) - return _PyUnicode_FromFormat(format, *cargs) - - def check_format(expected, format, *args): - text = PyUnicode_FromFormat(format, *args) - self.assertEqual(expected, text) - - # ascii format, non-ascii argument - check_format('ascii\x7f=unicode\xe9', - b'ascii\x7f=%U', 'unicode\xe9') - - # non-ascii format, ascii argument: ensure that PyUnicode_FromFormatV() - # raises an error - self.assertRaisesRegex(ValueError, - r'^PyUnicode_FromFormatV\(\) expects an ASCII-encoded format ' - 'string, got a non-ASCII byte: 0xe9$', - PyUnicode_FromFormat, b'unicode\xe9=%s', 'ascii') - - # test "%c" - check_format('\uabcd', - b'%c', c_int(0xabcd)) - check_format('\U0010ffff', - b'%c', c_int(0x10ffff)) - with self.assertRaises(OverflowError): - PyUnicode_FromFormat(b'%c', c_int(0x110000)) - # Issue #18183 - check_format('\U00010000\U00100000', - b'%c%c', c_int(0x10000), c_int(0x100000)) - - # test "%" - check_format('%', - b'%') - check_format('%', - b'%%') - check_format('%s', - b'%%s') - check_format('[%]', - b'[%%]') - check_format('%abc', - b'%%%s', b'abc') - - # truncated string - check_format('abc', - b'%.3s', b'abcdef') - check_format('abc[\ufffd', - b'%.5s', 'abc[\u20ac]'.encode('utf8')) - check_format("'\\u20acABC'", - b'%A', '\u20acABC') - check_format("'\\u20", - b'%.5A', '\u20acABCDEF') - check_format("'\u20acABC'", - b'%R', '\u20acABC') - check_format("'\u20acA", - b'%.3R', '\u20acABCDEF') - check_format('\u20acAB', - b'%.3S', '\u20acABCDEF') - check_format('\u20acAB', - b'%.3U', '\u20acABCDEF') - check_format('\u20acAB', - b'%.3V', '\u20acABCDEF', None) - check_format('abc[\ufffd', - b'%.5V', None, 'abc[\u20ac]'.encode('utf8')) - - # following tests comes from #7330 - # test width modifier and precision modifier with %S - check_format("repr= abc", - b'repr=%5S', 'abc') - check_format("repr=ab", - b'repr=%.2S', 'abc') - check_format("repr= ab", - b'repr=%5.2S', 'abc') - - # test width modifier and precision modifier with %R - check_format("repr= 'abc'", - b'repr=%8R', 'abc') - check_format("repr='ab", - b'repr=%.3R', 'abc') - check_format("repr= 'ab", - b'repr=%5.3R', 'abc') - - # test width modifier and precision modifier with %A - check_format("repr= 'abc'", - b'repr=%8A', 'abc') - check_format("repr='ab", - b'repr=%.3A', 'abc') - check_format("repr= 'ab", - b'repr=%5.3A', 'abc') - - # test width modifier and precision modifier with %s - check_format("repr= abc", - b'repr=%5s', b'abc') - check_format("repr=ab", - b'repr=%.2s', b'abc') - check_format("repr= ab", - b'repr=%5.2s', b'abc') - - # test width modifier and precision modifier with %U - check_format("repr= abc", - b'repr=%5U', 'abc') - check_format("repr=ab", - b'repr=%.2U', 'abc') - check_format("repr= ab", - b'repr=%5.2U', 'abc') - - # test width modifier and precision modifier with %V - check_format("repr= abc", - b'repr=%5V', 'abc', b'123') - check_format("repr=ab", - b'repr=%.2V', 'abc', b'123') - check_format("repr= ab", - b'repr=%5.2V', 'abc', b'123') - check_format("repr= 123", - b'repr=%5V', None, b'123') - check_format("repr=12", - b'repr=%.2V', None, b'123') - check_format("repr= 12", - b'repr=%5.2V', None, b'123') - - # test integer formats (%i, %d, %u) - check_format('010', - b'%03i', c_int(10)) - check_format('0010', - b'%0.4i', c_int(10)) - check_format('-123', - b'%i', c_int(-123)) - check_format('-123', - b'%li', c_long(-123)) - check_format('-123', - b'%lli', c_longlong(-123)) - check_format('-123', - b'%zi', c_ssize_t(-123)) - - check_format('-123', - b'%d', c_int(-123)) - check_format('-123', - b'%ld', c_long(-123)) - check_format('-123', - b'%lld', c_longlong(-123)) - check_format('-123', - b'%zd', c_ssize_t(-123)) - - check_format('123', - b'%u', c_uint(123)) - check_format('123', - b'%lu', c_ulong(123)) - check_format('123', - b'%llu', c_ulonglong(123)) - check_format('123', - b'%zu', c_size_t(123)) - - # test long output - min_longlong = -(2 ** (8 * sizeof(c_longlong) - 1)) - max_longlong = -min_longlong - 1 - check_format(str(min_longlong), - b'%lld', c_longlong(min_longlong)) - check_format(str(max_longlong), - b'%lld', c_longlong(max_longlong)) - max_ulonglong = 2 ** (8 * sizeof(c_ulonglong)) - 1 - check_format(str(max_ulonglong), - b'%llu', c_ulonglong(max_ulonglong)) - PyUnicode_FromFormat(b'%p', c_void_p(-1)) - - # test padding (width and/or precision) - check_format('123'.rjust(10, '0'), - b'%010i', c_int(123)) - check_format('123'.rjust(100), - b'%100i', c_int(123)) - check_format('123'.rjust(100, '0'), - b'%.100i', c_int(123)) - check_format('123'.rjust(80, '0').rjust(100), - b'%100.80i', c_int(123)) - - check_format('123'.rjust(10, '0'), - b'%010u', c_uint(123)) - check_format('123'.rjust(100), - b'%100u', c_uint(123)) - check_format('123'.rjust(100, '0'), - b'%.100u', c_uint(123)) - check_format('123'.rjust(80, '0').rjust(100), - b'%100.80u', c_uint(123)) - - check_format('123'.rjust(10, '0'), - b'%010x', c_int(0x123)) - check_format('123'.rjust(100), - b'%100x', c_int(0x123)) - check_format('123'.rjust(100, '0'), - b'%.100x', c_int(0x123)) - check_format('123'.rjust(80, '0').rjust(100), - b'%100.80x', c_int(0x123)) - - # test %A - check_format(r"%A:'abc\xe9\uabcd\U0010ffff'", - b'%%A:%A', 'abc\xe9\uabcd\U0010ffff') - - # test %V - check_format('repr=abc', - b'repr=%V', 'abc', b'xyz') - - # Test string decode from parameter of %s using utf-8. - # b'\xe4\xba\xba\xe6\xb0\x91' is utf-8 encoded byte sequence of - # '\u4eba\u6c11' - check_format('repr=\u4eba\u6c11', - b'repr=%V', None, b'\xe4\xba\xba\xe6\xb0\x91') - - #Test replace error handler. - check_format('repr=abc\ufffd', - b'repr=%V', None, b'abc\xff') - - # not supported: copy the raw format string. these tests are just here - # to check for crashes and should not be considered as specifications - check_format('%s', - b'%1%s', b'abc') - check_format('%1abc', - b'%1abc') - check_format('%+i', - b'%+i', c_int(10)) - check_format('%.%s', - b'%.%s', b'abc') - - # Issue #33817: empty strings - check_format('', - b'') - check_format('', - b'%s', b'') - - # Test PyUnicode_AsWideChar() - @support.cpython_only - def test_aswidechar(self): - from _testcapi import unicode_aswidechar - import_helper.import_module('ctypes') - from ctypes import c_wchar, sizeof - - wchar, size = unicode_aswidechar('abcdef', 2) - self.assertEqual(size, 2) - self.assertEqual(wchar, 'ab') - - wchar, size = unicode_aswidechar('abc', 3) - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc') - - wchar, size = unicode_aswidechar('abc', 4) - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc\0') - - wchar, size = unicode_aswidechar('abc', 10) - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc\0') - - wchar, size = unicode_aswidechar('abc\0def', 20) - self.assertEqual(size, 7) - self.assertEqual(wchar, 'abc\0def\0') - - nonbmp = chr(0x10ffff) - if sizeof(c_wchar) == 2: - buflen = 3 - nchar = 2 - else: # sizeof(c_wchar) == 4 - buflen = 2 - nchar = 1 - wchar, size = unicode_aswidechar(nonbmp, buflen) - self.assertEqual(size, nchar) - self.assertEqual(wchar, nonbmp + '\0') - - # Test PyUnicode_AsWideCharString() - @support.cpython_only - def test_aswidecharstring(self): - from _testcapi import unicode_aswidecharstring - import_helper.import_module('ctypes') - from ctypes import c_wchar, sizeof - - wchar, size = unicode_aswidecharstring('abc') - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc\0') - - wchar, size = unicode_aswidecharstring('abc\0def') - self.assertEqual(size, 7) - self.assertEqual(wchar, 'abc\0def\0') - - nonbmp = chr(0x10ffff) - if sizeof(c_wchar) == 2: - nchar = 2 - else: # sizeof(c_wchar) == 4 - nchar = 1 - wchar, size = unicode_aswidecharstring(nonbmp) - self.assertEqual(size, nchar) - self.assertEqual(wchar, nonbmp + '\0') - - # Test PyUnicode_AsUCS4() - @support.cpython_only - def test_asucs4(self): - from _testcapi import unicode_asucs4 - for s in ['abc', '\xa1\xa2', '\u4f60\u597d', 'a\U0001f600', - 'a\ud800b\udfffc', '\ud834\udd1e']: - l = len(s) - self.assertEqual(unicode_asucs4(s, l, True), s+'\0') - self.assertEqual(unicode_asucs4(s, l, False), s+'\uffff') - self.assertEqual(unicode_asucs4(s, l+1, True), s+'\0\uffff') - self.assertEqual(unicode_asucs4(s, l+1, False), s+'\0\uffff') - self.assertRaises(SystemError, unicode_asucs4, s, l-1, True) - self.assertRaises(SystemError, unicode_asucs4, s, l-2, False) - s = '\0'.join([s, s]) - self.assertEqual(unicode_asucs4(s, len(s), True), s+'\0') - self.assertEqual(unicode_asucs4(s, len(s), False), s+'\uffff') - - # Test PyUnicode_AsUTF8() - @support.cpython_only - def test_asutf8(self): - from _testcapi import unicode_asutf8 - - bmp = '\u0100' - bmp2 = '\uffff' - nonbmp = chr(0x10ffff) - - self.assertEqual(unicode_asutf8(bmp), b'\xc4\x80') - self.assertEqual(unicode_asutf8(bmp2), b'\xef\xbf\xbf') - self.assertEqual(unicode_asutf8(nonbmp), b'\xf4\x8f\xbf\xbf') - self.assertRaises(UnicodeEncodeError, unicode_asutf8, 'a\ud800b\udfffc') - - # Test PyUnicode_AsUTF8AndSize() - @support.cpython_only - def test_asutf8andsize(self): - from _testcapi import unicode_asutf8andsize - - bmp = '\u0100' - bmp2 = '\uffff' - nonbmp = chr(0x10ffff) - - self.assertEqual(unicode_asutf8andsize(bmp), (b'\xc4\x80', 2)) - self.assertEqual(unicode_asutf8andsize(bmp2), (b'\xef\xbf\xbf', 3)) - self.assertEqual(unicode_asutf8andsize(nonbmp), (b'\xf4\x8f\xbf\xbf', 4)) - self.assertRaises(UnicodeEncodeError, unicode_asutf8andsize, 'a\ud800b\udfffc') - - # Test PyUnicode_FindChar() - @support.cpython_only - def test_findchar(self): - from _testcapi import unicode_findchar - - for str in "\xa1", "\u8000\u8080", "\ud800\udc02", "\U0001f100\U0001f1f1": - for i, ch in enumerate(str): - self.assertEqual(unicode_findchar(str, ord(ch), 0, len(str), 1), i) - self.assertEqual(unicode_findchar(str, ord(ch), 0, len(str), -1), i) - - str = "!>_= end - self.assertEqual(unicode_findchar(str, ord('!'), 0, 0, 1), -1) - self.assertEqual(unicode_findchar(str, ord('!'), len(str), 0, 1), -1) - # negative - self.assertEqual(unicode_findchar(str, ord('!'), -len(str), -1, 1), 0) - self.assertEqual(unicode_findchar(str, ord('!'), -len(str), -1, -1), 0) - - # Test PyUnicode_CopyCharacters() - @support.cpython_only - def test_copycharacters(self): - from _testcapi import unicode_copycharacters - - strings = [ - 'abcde', '\xa1\xa2\xa3\xa4\xa5', - '\u4f60\u597d\u4e16\u754c\uff01', - '\U0001f600\U0001f601\U0001f602\U0001f603\U0001f604' - ] - - for idx, from_ in enumerate(strings): - # wide -> narrow: exceed maxchar limitation - for to in strings[:idx]: - self.assertRaises( - SystemError, - unicode_copycharacters, to, 0, from_, 0, 5 - ) - # same kind - for from_start in range(5): - self.assertEqual( - unicode_copycharacters(from_, 0, from_, from_start, 5), - (from_[from_start:from_start+5].ljust(5, '\0'), - 5-from_start) - ) - for to_start in range(5): - self.assertEqual( - unicode_copycharacters(from_, to_start, from_, to_start, 5), - (from_[to_start:to_start+5].rjust(5, '\0'), - 5-to_start) - ) - # narrow -> wide - # Tests omitted since this creates invalid strings. - - s = strings[0] - self.assertRaises(IndexError, unicode_copycharacters, s, 6, s, 0, 5) - self.assertRaises(IndexError, unicode_copycharacters, s, -1, s, 0, 5) - self.assertRaises(IndexError, unicode_copycharacters, s, 0, s, 6, 5) - self.assertRaises(IndexError, unicode_copycharacters, s, 0, s, -1, 5) - self.assertRaises(SystemError, unicode_copycharacters, s, 1, s, 0, 5) - self.assertRaises(SystemError, unicode_copycharacters, s, 0, s, 0, -1) - self.assertRaises(SystemError, unicode_copycharacters, s, 0, b'', 0, 0) - - @support.cpython_only - @support.requires_legacy_unicode_capi - def test_encode_decimal(self): - from _testcapi import unicode_encodedecimal - with warnings_helper.check_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(unicode_encodedecimal('123'), - b'123') - self.assertEqual(unicode_encodedecimal('\u0663.\u0661\u0664'), - b'3.14') - self.assertEqual(unicode_encodedecimal( - "\N{EM SPACE}3.14\N{EN SPACE}"), b' 3.14 ') - self.assertRaises(UnicodeEncodeError, - unicode_encodedecimal, "123\u20ac", "strict") - self.assertRaisesRegex( - ValueError, - "^'decimal' codec can't encode character", - unicode_encodedecimal, "123\u20ac", "replace") - - @support.cpython_only - @support.requires_legacy_unicode_capi - def test_transform_decimal(self): - from _testcapi import unicode_transformdecimaltoascii as transform_decimal - with warnings_helper.check_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(transform_decimal('123'), - '123') - self.assertEqual(transform_decimal('\u0663.\u0661\u0664'), - '3.14') - self.assertEqual(transform_decimal("\N{EM SPACE}3.14\N{EN SPACE}"), - "\N{EM SPACE}3.14\N{EN SPACE}") - self.assertEqual(transform_decimal('123\u20ac'), - '123\u20ac') - - @support.cpython_only - def test_pep393_utf8_caching_bug(self): - # Issue #25709: Problem with string concatenation and utf-8 cache - from _testcapi import getargs_s_hash - for k in 0x24, 0xa4, 0x20ac, 0x1f40d: - s = '' - for i in range(5): - # Due to CPython specific optimization the 's' string can be - # resized in-place. - s += chr(k) - # Parsing with the "s#" format code calls indirectly - # PyUnicode_AsUTF8AndSize() which creates the UTF-8 - # encoded string cached in the Unicode object. - self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1)) - # Check that the second call returns the same result - self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1)) - class StringModuleTest(unittest.TestCase): def test_formatter_parser(self): def parse(format): @@ -3109,6 +2745,30 @@ def split(name): ]]) self.assertRaises(TypeError, _string.formatter_field_name_split, 1) + def test_str_subclass_attr(self): + + name = StrSubclass("name") + name2 = StrSubclass("name2") + class Bag: + pass + + o = Bag() + with self.assertRaises(AttributeError): + delattr(o, name) + setattr(o, name, 1) + self.assertEqual(o.name, 1) + o.name = 2 + self.assertEqual(list(o.__dict__), [name]) + + with self.assertRaises(AttributeError): + delattr(o, name2) + with self.assertRaises(AttributeError): + del o.name2 + setattr(o, name2, 3) + self.assertEqual(o.name2, 3) + o.name2 = 4 + self.assertEqual(list(o.__dict__), [name, name2]) + if __name__ == "__main__": unittest.main() From b3af69ee2d1ebcd0017632110faadc9d8df1872a Mon Sep 17 00:00:00 2001 From: CPython developers <> Date: Fri, 20 Oct 2023 18:11:45 -0700 Subject: [PATCH 15/19] Update test_functools from Python 3.12 --- Lib/test/test_functools.py | 682 ++++++++++++++++++++++++++++++++----- 1 file changed, 594 insertions(+), 88 deletions(-) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 71748df8ed..fb2dcf7a51 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -13,19 +13,29 @@ import typing import unittest import unittest.mock +import weakref +import gc from weakref import proxy import contextlib +from inspect import Signature from test.support import import_helper from test.support import threading_helper import functools -py_functools = import_helper.import_fresh_module('functools', blocked=['_functools']) -c_functools = import_helper.import_fresh_module('functools', fresh=['_functools']) +py_functools = import_helper.import_fresh_module('functools', + blocked=['_functools']) +c_functools = import_helper.import_fresh_module('functools', + fresh=['_functools']) decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal']) +_partial_types = [py_functools.partial] +if c_functools: + _partial_types.append(c_functools.partial) + + @contextlib.contextmanager def replaced_module(name, replacement): original_module = sys.modules[name] @@ -162,6 +172,7 @@ def test_weakref(self): p = proxy(f) self.assertEqual(f.func, p.func) f = None + support.gc_collect() # For PyPy or other GCs. self.assertRaises(ReferenceError, getattr, p, 'func') def test_with_bound_and_unbound_methods(self): @@ -196,7 +207,7 @@ def test_repr(self): kwargs = {'a': object(), 'b': object()} kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 'b={b!r}, a={a!r}'.format_map(kwargs)] - if self.partial in (c_functools.partial, py_functools.partial): + if self.partial in _partial_types: name = 'functools.partial' else: name = self.partial.__name__ @@ -218,7 +229,7 @@ def test_repr(self): for kwargs_repr in kwargs_reprs]) def test_recursive_repr(self): - if self.partial in (c_functools.partial, py_functools.partial): + if self.partial in _partial_types: name = 'functools.partial' else: name = self.partial.__name__ @@ -245,7 +256,7 @@ def test_recursive_repr(self): f.__setstate__((capture, (), {}, {})) def test_pickle(self): - with self.AllowPickle(): + with replaced_module('functools', self.module): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -328,7 +339,7 @@ def test_setstate_subclasses(self): self.assertIs(type(r[0]), tuple) def test_recursive_pickle(self): - with self.AllowPickle(): + with replaced_module('functools', self.module): f = self.partial(capture) f.__setstate__((f, (), {}, {})) try: @@ -382,24 +393,9 @@ def __getitem__(self, key): @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialC(TestPartial, unittest.TestCase): if c_functools: + module = c_functools partial = c_functools.partial - class AllowPickle: - def __enter__(self): - return self - def __exit__(self, type, value, tb): - return False - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_pickle(self): - super().test_pickle() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_recursive_pickle(self): - super().test_recursive_pickle() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_attributes_unwritable(self): @@ -444,15 +440,9 @@ def __str__(self): class TestPartialPy(TestPartial, unittest.TestCase): + module = py_functools partial = py_functools.partial - class AllowPickle: - def __init__(self): - self._cm = replaced_module("functools", py_functools) - def __enter__(self): - return self._cm.__enter__() - def __exit__(self, type, value, tb): - return self._cm.__exit__(type, value, tb) if c_functools: class CPartialSubclass(c_functools.partial): @@ -579,11 +569,9 @@ class B(object): with self.assertRaises(TypeError): class B: method = functools.partialmethod() - with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): class B: method = functools.partialmethod(func=capture, a=1) - b = B() - self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3})) def test_repr(self): self.assertEqual(repr(vars(self.A)['both']), @@ -634,6 +622,8 @@ def check_wrapper(self, wrapper, wrapped, def _default_update(self): + # XXX: RUSTPYTHON; f[T] is not supported yet + # def f[T](a:'This is a new annotation'): def f(a:'This is a new annotation'): """This is a test""" pass @@ -644,15 +634,19 @@ def wrapper(b:'This is the prior annotation'): functools.update_wrapper(wrapper, f) return wrapper, f + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_default_update(self): wrapper, f = self._default_update() self.check_wrapper(wrapper, f) + T, = f.__type_params__ self.assertIs(wrapper.__wrapped__, f) self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.attr, 'This is also a test') self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') self.assertNotIn('b', wrapper.__annotations__) + self.assertEqual(wrapper.__type_params__, (T,)) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") @@ -959,6 +953,10 @@ def mycmp(x, y): self.assertRaises(TypeError, hash, k) self.assertNotIsInstance(k, collections.abc.Hashable) + def test_cmp_to_signature(self): + self.assertEqual(str(Signature.from_callable(self.cmp_to_key)), + '(mycmp)') + @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): @@ -1000,6 +998,18 @@ def test_sort_int(self): def test_sort_int_str(self): super().test_sort_int_str() + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_cmp_to_signature(self): + super().test_cmp_to_signature() + + @support.cpython_only + def test_disallow_instantiation(self): + # Ensure that the type disallows instantiation (bpo-43916) + support.check_disallow_instantiation( + self, type(c_functools.cmp_to_key(None)) + ) + class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): cmp_to_key = staticmethod(py_functools.cmp_to_key) @@ -1093,6 +1103,73 @@ def test_no_operations_defined(self): class A: pass + def test_notimplemented(self): + # Verify NotImplemented results are correctly handled + @functools.total_ordering + class ImplementsLessThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThan): + return self.value == other.value + return False + def __lt__(self, other): + if isinstance(other, ImplementsLessThan): + return self.value < other.value + return NotImplemented + + @functools.total_ordering + class ImplementsLessThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value == other.value + return False + def __le__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value <= other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value == other.value + return False + def __gt__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value > other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value == other.value + return False + def __ge__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value >= other.value + return NotImplemented + + self.assertIs(ImplementsLessThan(1).__le__(1), NotImplemented) + self.assertIs(ImplementsLessThan(1).__gt__(1), NotImplemented) + self.assertIs(ImplementsLessThan(1).__ge__(1), NotImplemented) + self.assertIs(ImplementsLessThanEqualTo(1).__lt__(1), NotImplemented) + self.assertIs(ImplementsLessThanEqualTo(1).__gt__(1), NotImplemented) + self.assertIs(ImplementsLessThanEqualTo(1).__ge__(1), NotImplemented) + self.assertIs(ImplementsGreaterThan(1).__lt__(1), NotImplemented) + self.assertIs(ImplementsGreaterThan(1).__gt__(1), NotImplemented) + self.assertIs(ImplementsGreaterThan(1).__ge__(1), NotImplemented) + self.assertIs(ImplementsGreaterThanEqualTo(1).__lt__(1), NotImplemented) + self.assertIs(ImplementsGreaterThanEqualTo(1).__le__(1), NotImplemented) + self.assertIs(ImplementsGreaterThanEqualTo(1).__gt__(1), NotImplemented) + def test_type_error_when_not_implemented(self): # bug 10042; ensure stack overflow does not occur # when decorated types return NotImplemented @@ -1208,6 +1285,34 @@ def test_pickle(self): method_copy = pickle.loads(pickle.dumps(method, proto)) self.assertIs(method_copy, method) + + def test_total_ordering_for_metaclasses_issue_44605(self): + + @functools.total_ordering + class SortableMeta(type): + def __new__(cls, name, bases, ns): + return super().__new__(cls, name, bases, ns) + + def __lt__(self, other): + if not isinstance(other, SortableMeta): + pass + return self.__name__ < other.__name__ + + def __eq__(self, other): + if not isinstance(other, SortableMeta): + pass + return self.__name__ == other.__name__ + + class B(metaclass=SortableMeta): + pass + + class A(metaclass=SortableMeta): + pass + + self.assertTrue(A < B) + self.assertFalse(A > B) + + @functools.total_ordering class Orderable_LT: def __init__(self, value): @@ -1218,6 +1323,25 @@ def __eq__(self, other): return self.value == other.value +class TestCache: + # This tests that the pass-through is working as designed. + # The underlying functionality is tested in TestLRU. + + def test_cache(self): + @self.module.cache + def fib(n): + if n < 2: + return n + return fib(n-1) + fib(n-2) + self.assertEqual([fib(n) for n in range(16)], + [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) + self.assertEqual(fib.cache_info(), + self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) + fib.cache_clear() + self.assertEqual(fib.cache_info(), + self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + + class TestLRU: def test_lru(self): @@ -1411,7 +1535,7 @@ def test_lru_reentrancy_with_len(self): def test_lru_star_arg_handling(self): # Test regression that arose in ea064ff3c10f - @functools.lru_cache() + @self.module.lru_cache() def f(*args): return args @@ -1423,11 +1547,11 @@ def test_lru_type_error(self): # lru_cache was leaking when one of the arguments # wasn't cacheable. - @functools.lru_cache(maxsize=None) + @self.module.lru_cache(maxsize=None) def infinite_cache(o): pass - @functools.lru_cache(maxsize=10) + @self.module.lru_cache(maxsize=10) def limited_cache(o): pass @@ -1492,6 +1616,33 @@ def square(x): self.assertEqual(square.cache_info().hits, 4) self.assertEqual(square.cache_info().misses, 4) + def test_lru_cache_typed_is_not_recursive(self): + cached = self.module.lru_cache(typed=True)(repr) + + self.assertEqual(cached(1), '1') + self.assertEqual(cached(True), 'True') + self.assertEqual(cached(1.0), '1.0') + self.assertEqual(cached(0), '0') + self.assertEqual(cached(False), 'False') + self.assertEqual(cached(0.0), '0.0') + + self.assertEqual(cached((1,)), '(1,)') + self.assertEqual(cached((True,)), '(1,)') + self.assertEqual(cached((1.0,)), '(1,)') + self.assertEqual(cached((0,)), '(0,)') + self.assertEqual(cached((False,)), '(0,)') + self.assertEqual(cached((0.0,)), '(0,)') + + class T(tuple): + pass + + self.assertEqual(cached(T((1,))), '(1,)') + self.assertEqual(cached(T((True,))), '(1,)') + self.assertEqual(cached(T((1.0,))), '(1,)') + self.assertEqual(cached(T((0,))), '(0,)') + self.assertEqual(cached(T((False,))), '(0,)') + self.assertEqual(cached(T((0.0,))), '(0,)') + def test_lru_with_keyword_args(self): @self.module.lru_cache() def fib(n): @@ -1542,6 +1693,7 @@ def f(zomg: 'zomg_annotation'): # TODO: RUSTPYTHON @unittest.expectedFailure + @threading_helper.requires_working_threading() def test_lru_cache_threaded(self): n, m = 5, 11 def orig(x, y): @@ -1590,6 +1742,7 @@ def clear(): finally: sys.setswitchinterval(orig_si) + @threading_helper.requires_working_threading() def test_lru_cache_threaded2(self): # Simultaneous call with the same arguments n, m = 5, 7 @@ -1617,6 +1770,7 @@ def test(): pause.reset() self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) + @threading_helper.requires_working_threading() def test_lru_cache_threaded3(self): @self.module.lru_cache(maxsize=2) def f(x): @@ -1717,14 +1871,62 @@ def orig(x, y): f_copy = copy.deepcopy(f) self.assertIs(f_copy, f) + def test_lru_cache_parameters(self): + @self.module.lru_cache(maxsize=2) + def f(): + return 1 + self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False}) + + @self.module.lru_cache(maxsize=1000, typed=True) + def f(): + return 1 + self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True}) + + def test_lru_cache_weakrefable(self): + @self.module.lru_cache + def test_function(x): + return x + + class A: + @self.module.lru_cache + def test_method(self, x): + return (self, x) + + @staticmethod + @self.module.lru_cache + def test_staticmethod(x): + return (self, x) + + refs = [weakref.ref(test_function), + weakref.ref(A.test_method), + weakref.ref(A.test_staticmethod)] + + for ref in refs: + self.assertIsNotNone(ref()) + + del A + del test_function + gc.collect() + + for ref in refs: + self.assertIsNone(ref()) + + def test_common_signatures(self): + def orig(): ... + lru = self.module.lru_cache(1)(orig) + + self.assertEqual(str(Signature.from_callable(lru.cache_info)), '()') + self.assertEqual(str(Signature.from_callable(lru.cache_clear)), '()') + @py_functools.lru_cache() def py_cached_func(x, y): return 3 * x + y -@c_functools.lru_cache() -def c_cached_func(x, y): - return 3 * x + y +if c_functools: + @c_functools.lru_cache() + def c_cached_func(x, y): + return 3 * x + y class TestLRUPy(TestLRU, unittest.TestCase): @@ -1741,18 +1943,20 @@ def cached_staticmeth(x, y): return 3 * x + y +@unittest.skipUnless(c_functools, 'requires the C _functools module') class TestLRUC(TestLRU, unittest.TestCase): - module = c_functools - cached_func = c_cached_func, + if c_functools: + module = c_functools + cached_func = c_cached_func, - @module.lru_cache() - def cached_meth(self, x, y): - return 3 * x + y + @module.lru_cache() + def cached_meth(self, x, y): + return 3 * x + y - @staticmethod - @module.lru_cache() - def cached_staticmeth(x, y): - return 3 * x + y + @staticmethod + @module.lru_cache() + def cached_staticmeth(x, y): + return 3 * x + y class TestSingleDispatch(unittest.TestCase): @@ -1867,7 +2071,7 @@ class D(collections.defaultdict): c.MutableSequence.register(D) bases = [c.MutableSequence, c.MutableMapping] for haystack in permutations(bases): - m = mro(D, bases) + m = mro(D, haystack) self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, collections.defaultdict, dict, c.MutableMapping, c.Mapping, c.Collection, c.Sized, c.Iterable, c.Container, @@ -2370,7 +2574,7 @@ def _(cls, arg): self.assertEqual(A.t(0.0).arg, "base") def test_abstractmethod_register(self): - class Abstract(abc.ABCMeta): + class Abstract(metaclass=abc.ABCMeta): @functools.singledispatchmethod @abc.abstractmethod @@ -2378,6 +2582,10 @@ def add(self, x, y): pass self.assertTrue(Abstract.add.__isabstractmethod__) + self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__) + + with self.assertRaises(TypeError): + Abstract() def test_type_ann_register(self): class A: @@ -2396,6 +2604,183 @@ def _(self, arg: str): self.assertEqual(a.t(''), "str") self.assertEqual(a.t(0.0), "base") + def test_staticmethod_type_ann_register(self): + class A: + @functools.singledispatchmethod + @staticmethod + def t(arg): + return arg + @t.register + @staticmethod + def _(arg: int): + return isinstance(arg, int) + @t.register + @staticmethod + def _(arg: str): + return isinstance(arg, str) + a = A() + + self.assertTrue(A.t(0)) + self.assertTrue(A.t('')) + self.assertEqual(A.t(0.0), 0.0) + + def test_classmethod_type_ann_register(self): + class A: + def __init__(self, arg): + self.arg = arg + + @functools.singledispatchmethod + @classmethod + def t(cls, arg): + return cls("base") + @t.register + @classmethod + def _(cls, arg: int): + return cls("int") + @t.register + @classmethod + def _(cls, arg: str): + return cls("str") + + self.assertEqual(A.t(0).arg, "int") + self.assertEqual(A.t('').arg, "str") + self.assertEqual(A.t(0.0).arg, "base") + + def test_method_wrapping_attributes(self): + class A: + @functools.singledispatchmethod + def func(self, arg: int) -> str: + """My function docstring""" + return str(arg) + @functools.singledispatchmethod + @classmethod + def cls_func(cls, arg: int) -> str: + """My function docstring""" + return str(arg) + @functools.singledispatchmethod + @staticmethod + def static_func(arg: int) -> str: + """My function docstring""" + return str(arg) + + for meth in ( + A.func, + A().func, + A.cls_func, + A().cls_func, + A.static_func, + A().static_func + ): + with self.subTest(meth=meth): + self.assertEqual(meth.__doc__, 'My function docstring') + self.assertEqual(meth.__annotations__['arg'], int) + + self.assertEqual(A.func.__name__, 'func') + self.assertEqual(A().func.__name__, 'func') + self.assertEqual(A.cls_func.__name__, 'cls_func') + self.assertEqual(A().cls_func.__name__, 'cls_func') + self.assertEqual(A.static_func.__name__, 'static_func') + self.assertEqual(A().static_func.__name__, 'static_func') + + def test_double_wrapped_methods(self): + def classmethod_friendly_decorator(func): + wrapped = func.__func__ + @classmethod + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + return wrapped(*args, **kwargs) + return wrapper + + class WithoutSingleDispatch: + @classmethod + @contextlib.contextmanager + def cls_context_manager(cls, arg: int) -> str: + try: + yield str(arg) + finally: + return 'Done' + + @classmethod_friendly_decorator + @classmethod + def decorated_classmethod(cls, arg: int) -> str: + return str(arg) + + class WithSingleDispatch: + @functools.singledispatchmethod + @classmethod + @contextlib.contextmanager + def cls_context_manager(cls, arg: int) -> str: + """My function docstring""" + try: + yield str(arg) + finally: + return 'Done' + + @functools.singledispatchmethod + @classmethod_friendly_decorator + @classmethod + def decorated_classmethod(cls, arg: int) -> str: + """My function docstring""" + return str(arg) + + # These are sanity checks + # to test the test itself is working as expected + with WithoutSingleDispatch.cls_context_manager(5) as foo: + without_single_dispatch_foo = foo + + with WithSingleDispatch.cls_context_manager(5) as foo: + single_dispatch_foo = foo + + self.assertEqual(without_single_dispatch_foo, single_dispatch_foo) + self.assertEqual(single_dispatch_foo, '5') + + self.assertEqual( + WithoutSingleDispatch.decorated_classmethod(5), + WithSingleDispatch.decorated_classmethod(5) + ) + + self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5') + + # Behavioural checks now follow + for method_name in ('cls_context_manager', 'decorated_classmethod'): + with self.subTest(method=method_name): + self.assertEqual( + getattr(WithSingleDispatch, method_name).__name__, + getattr(WithoutSingleDispatch, method_name).__name__ + ) + + self.assertEqual( + getattr(WithSingleDispatch(), method_name).__name__, + getattr(WithoutSingleDispatch(), method_name).__name__ + ) + + for meth in ( + WithSingleDispatch.cls_context_manager, + WithSingleDispatch().cls_context_manager, + WithSingleDispatch.decorated_classmethod, + WithSingleDispatch().decorated_classmethod + ): + with self.subTest(meth=meth): + self.assertEqual(meth.__doc__, 'My function docstring') + self.assertEqual(meth.__annotations__['arg'], int) + + self.assertEqual( + WithSingleDispatch.cls_context_manager.__name__, + 'cls_context_manager' + ) + self.assertEqual( + WithSingleDispatch().cls_context_manager.__name__, + 'cls_context_manager' + ) + self.assertEqual( + WithSingleDispatch.decorated_classmethod.__name__, + 'decorated_classmethod' + ) + self.assertEqual( + WithSingleDispatch().decorated_classmethod.__name__, + 'decorated_classmethod' + ) + def test_invalid_registrations(self): msg_prefix = "Invalid first argument to `register()`: " msg_suffix = ( @@ -2435,6 +2820,17 @@ def _(arg: typing.Iterable[str]): 'typing.Iterable[str] is not a class.' )) + with self.assertRaises(TypeError) as exc: + @i.register + def _(arg: typing.Union[int, typing.Iterable[str]]): + return "Invalid Union" + self.assertTrue(str(exc.exception).startswith( + "Invalid annotation for 'arg'." + )) + self.assertTrue(str(exc.exception).endswith( + 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' + )) + def test_invalid_positional_argument(self): @functools.singledispatch def f(*args): @@ -2443,6 +2839,134 @@ def f(*args): with self.assertRaisesRegex(TypeError, msg): f() + def test_union(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | float): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "typing.Union") + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + self.assertEqual(f(1.0), "types.UnionType") + + def test_union_conflict(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | str): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "types.UnionType") # last one wins + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + + def test_union_None(self): + @functools.singledispatch + def typing_union(arg): + return "default" + + @typing_union.register + def _(arg: typing.Union[str, None]): + return "typing.Union" + + self.assertEqual(typing_union(1), "default") + self.assertEqual(typing_union(""), "typing.Union") + self.assertEqual(typing_union(None), "typing.Union") + + @functools.singledispatch + def types_union(arg): + return "default" + + @types_union.register + def _(arg: int | None): + return "types.UnionType" + + self.assertEqual(types_union(""), "default") + self.assertEqual(types_union(1), "types.UnionType") + self.assertEqual(types_union(None), "types.UnionType") + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_register_genericalias(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int], lambda arg: "types.GenericAlias") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int], lambda arg: "typing.GenericAlias") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]") + + self.assertEqual(f([1]), "default") + self.assertEqual(f([1.0]), "default") + self.assertEqual(f(""), "default") + self.assertEqual(f(b""), "default") + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_register_genericalias_decorator(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int]) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int]) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int] | str) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int] | str) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_register_genericalias_annotation(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: list[int]): + return "types.GenericAlias" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: typing.List[float]): + return "typing.GenericAlias" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: list[int] | str): + return "types.UnionType(types.GenericAlias)" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: typing.List[float] | bytes): + return "typing.Union[typing.GenericAlias]" + + self.assertEqual(f([1]), "default") + self.assertEqual(f([1.0]), "default") + self.assertEqual(f(""), "default") + self.assertEqual(f(b""), "default") + class CachedCostItem: _cost = 1 @@ -2469,21 +2993,6 @@ def get_cost(self): cached_cost = py_functools.cached_property(get_cost) -class CachedCostItemWait: - - def __init__(self, event): - self._cost = 1 - self.lock = py_functools.RLock() - self.event = event - - @py_functools.cached_property - def cost(self): - self.event.wait(1) - with self.lock: - self._cost += 1 - return self._cost - - class CachedCostItemWithSlots: __slots__ = ('_cost') @@ -2508,28 +3017,6 @@ def test_cached_attribute_name_differs_from_func_name(self): self.assertEqual(item.get_cost(), 4) self.assertEqual(item.cached_cost, 3) - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_threaded(self): - go = threading.Event() - item = CachedCostItemWait(go) - - num_threads = 3 - - orig_si = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - try: - threads = [ - threading.Thread(target=lambda: item.cost) - for k in range(num_threads) - ] - with threading_helper.start_threads(threads): - go.set() - finally: - sys.setswitchinterval(orig_si) - - self.assertEqual(item.cost, 2) - # TODO: RUSTPYTHON @unittest.expectedFailure def test_object_with_slots(self): @@ -2559,7 +3046,7 @@ class MyClass(metaclass=MyMeta): @unittest.expectedFailure def test_reuse_different_names(self): """Disallow this case because decorated function a would not be cached.""" - with self.assertRaises(RuntimeError) as ctx: + with self.assertRaises(TypeError) as ctx: class ReusedCachedProperty: @py_functools.cached_property def a(self): @@ -2568,7 +3055,7 @@ def a(self): b = a self.assertEqual( - str(ctx.exception.__context__), + str(ctx.exception), str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) ) @@ -2614,6 +3101,25 @@ def test_access_from_class(self): def test_doc(self): self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") + def test_subclass_with___set__(self): + """Caching still works for a subclass defining __set__.""" + class readonly_cached_property(py_functools.cached_property): + def __set__(self, obj, value): + raise AttributeError("read only property") + + class Test: + def __init__(self, prop): + self._prop = prop + + @readonly_cached_property + def prop(self): + return self._prop + + t = Test(1) + self.assertEqual(t.prop, 1) + t._prop = 999 + self.assertEqual(t.prop, 1) + if __name__ == '__main__': unittest.main() From 682eef35923acf453d40578402653df275f8415b Mon Sep 17 00:00:00 2001 From: CPython developers <> Date: Fri, 20 Oct 2023 18:31:46 -0700 Subject: [PATCH 16/19] Update enum from Python 3.12 --- Lib/enum.py | 253 ++++++++++-------- Lib/test/test_enum.py | 603 ++++++++++++++++++++++++++++++------------ 2 files changed, 578 insertions(+), 278 deletions(-) diff --git a/Lib/enum.py b/Lib/enum.py index 625e9ea56a..c207dc234c 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -190,41 +190,48 @@ class property(DynamicClassAttribute): a corresponding enum member. """ + member = None + _attr_type = None + _cls_type = None + def __get__(self, instance, ownerclass=None): if instance is None: - try: - return ownerclass._member_map_[self.name] - except KeyError: + if self.member is not None: + return self.member + else: raise AttributeError( '%r has no attribute %r' % (ownerclass, self.name) ) - else: - if self.fget is None: - # look for a member by this name. - try: - return ownerclass._member_map_[self.name] - except KeyError: - raise AttributeError( - '%r has no attribute %r' % (ownerclass, self.name) - ) from None - else: - return self.fget(instance) + if self.fget is not None: + # use previous enum.property + return self.fget(instance) + elif self._attr_type == 'attr': + # look up previous attibute + return getattr(self._cls_type, self.name) + elif self._attr_type == 'desc': + # use previous descriptor + return getattr(instance._value_, self.name) + # look for a member by this name. + try: + return ownerclass._member_map_[self.name] + except KeyError: + raise AttributeError( + '%r has no attribute %r' % (ownerclass, self.name) + ) from None def __set__(self, instance, value): - if self.fset is None: - raise AttributeError( - " cannot set attribute %r" % (self.clsname, self.name) - ) - else: + if self.fset is not None: return self.fset(instance, value) + raise AttributeError( + " cannot set attribute %r" % (self.clsname, self.name) + ) def __delete__(self, instance): - if self.fdel is None: - raise AttributeError( - " cannot delete attribute %r" % (self.clsname, self.name) - ) - else: + if self.fdel is not None: return self.fdel(instance) + raise AttributeError( + " cannot delete attribute %r" % (self.clsname, self.name) + ) def __set_name__(self, ownerclass, name): self.name = name @@ -312,27 +319,38 @@ def __set_name__(self, enum_class, member_name): enum_class._member_names_.append(member_name) # if necessary, get redirect in place and then add it to _member_map_ found_descriptor = None + descriptor_type = None + class_type = None for base in enum_class.__mro__[1:]: - descriptor = base.__dict__.get(member_name) - if descriptor is not None: - if isinstance(descriptor, (property, DynamicClassAttribute)): - found_descriptor = descriptor + attr = base.__dict__.get(member_name) + if attr is not None: + if isinstance(attr, (property, DynamicClassAttribute)): + found_descriptor = attr + class_type = base + descriptor_type = 'enum' break - elif ( - hasattr(descriptor, 'fget') and - hasattr(descriptor, 'fset') and - hasattr(descriptor, 'fdel') - ): - found_descriptor = descriptor + elif _is_descriptor(attr): + found_descriptor = attr + descriptor_type = descriptor_type or 'desc' + class_type = class_type or base continue + else: + descriptor_type = 'attr' + class_type = base if found_descriptor: redirect = property() redirect.member = enum_member redirect.__set_name__(enum_class, member_name) - # earlier descriptor found; copy fget, fset, fdel to this one. - redirect.fget = found_descriptor.fget - redirect.fset = found_descriptor.fset - redirect.fdel = found_descriptor.fdel + if descriptor_type in ('enum','desc'): + # earlier descriptor found; copy fget, fset, fdel to this one. + redirect.fget = getattr(found_descriptor, 'fget', None) + redirect._get = getattr(found_descriptor, '__get__', None) + redirect.fset = getattr(found_descriptor, 'fset', None) + redirect._set = getattr(found_descriptor, '__set__', None) + redirect.fdel = getattr(found_descriptor, 'fdel', None) + redirect._del = getattr(found_descriptor, '__delete__', None) + redirect._attr_type = descriptor_type + redirect._cls_type = class_type setattr(enum_class, member_name, redirect) else: setattr(enum_class, member_name, enum_member) @@ -521,8 +539,13 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k # # adjust the sunders _order_ = classdict.pop('_order_', None) + _gnv = classdict.get('_generate_next_value_') + if _gnv is not None and type(_gnv) is not staticmethod: + _gnv = staticmethod(_gnv) # convert to normal dict classdict = dict(classdict.items()) + if _gnv is not None: + classdict['_generate_next_value_'] = _gnv # # data type of member and the controlling Enum class member_type, first_enum = metacls._get_mixins_(cls, bases) @@ -674,7 +697,7 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k 'member order does not match _order_:\n %r\n %r' % (enum_class._member_names_, _order_) ) - + # return enum_class def __bool__(cls): @@ -683,7 +706,7 @@ def __bool__(cls): """ return True - def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None): + def __call__(cls, value, names=None, *values, module=None, qualname=None, type=None, start=1, boundary=None): """ Either returns an existing member, or creates a new enum class. @@ -691,6 +714,8 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s to an enumeration member (i.e. Color(3)) and for the functional API (i.e. Color = Enum('Color', names='RED GREEN BLUE')). + The value lookup branch is chosen if the enum is final. + When used for the functional API: `value` will be the name of the new class. @@ -708,12 +733,20 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s `type`, if set, will be mixed in as the first base class. """ - if names is None: # simple value lookup + if cls._member_map_: + # simple value lookup if members exist + if names: + value = (value, names) + values return cls.__new__(cls, value) # otherwise, functional API: we're creating a new Enum type + if names is None and type is None: + # no body? no data-type? possibly wrong usage + raise TypeError( + f"{cls} has no members; specify `names=()` if you meant to create a new, empty, enum" + ) return cls._create_( - value, - names, + class_name=value, + names=names, module=module, qualname=qualname, type=type, @@ -721,26 +754,16 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s boundary=boundary, ) - def __contains__(cls, member): - """ - Return True if member is a member of this enum - raises TypeError if member is not an enum member + def __contains__(cls, value): + """Return True if `value` is in `cls`. - note: in 3.12 TypeError will no longer be raised, and True will also be - returned if member is the value of a member in this enum + `value` is in `cls` if: + 1) `value` is a member of `cls`, or + 2) `value` is the value of one of the `cls`'s members. """ - if not isinstance(member, Enum): - import warnings - warnings.warn( - "in 3.12 __contains__ will no longer raise TypeError, but will return True or\n" - "False depending on whether the value is a member or the value of a member", - DeprecationWarning, - stacklevel=2, - ) - raise TypeError( - "unsupported operand type(s) for 'in': '%s' and '%s'" % ( - type(member).__qualname__, cls.__class__.__qualname__)) - return isinstance(member, cls) and member._name_ in cls._member_map_ + if isinstance(value, cls): + return True + return value in cls._value2member_map_ or value in cls._unhashable_values_ def __delattr__(cls, attr): # nicer error message when someone tries to delete an attribute @@ -767,22 +790,6 @@ def __dir__(cls): # return whatever mixed-in data type has return sorted(set(dir(cls._member_type_)) | interesting) - def __getattr__(cls, name): - """ - Return the enum member matching `name` - - We use __getattr__ instead of descriptors or inserting into the enum - class' __dict__ in order to support `name` and `value` being both - properties for enum members (which live in the class' __dict__) and - enum members themselves. - """ - if _is_dunder(name): - raise AttributeError(name) - try: - return cls._member_map_[name] - except KeyError: - raise AttributeError(name) from None - def __getitem__(cls, name): """ Return the member matching `name`. @@ -863,6 +870,8 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s value = first_enum._generate_next_value_(name, start, count, last_values[:]) last_values.append(value) names.append((name, value)) + if names is None: + names = () # Here, names is either an iterable of (name, value) or a mapping. for item in names: @@ -872,13 +881,15 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s member_name, member_value = item classdict[member_name] = member_value - # TODO: replace the frame hack if a blessed way to know the calling - # module is ever developed if module is None: try: - module = sys._getframe(2).f_globals['__name__'] - except (AttributeError, ValueError, KeyError): - pass + module = sys._getframemodulename(2) + except AttributeError: + # Fall back on _getframe if _getframemodulename is missing + try: + module = sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError, KeyError): + pass if module is None: _make_class_unpicklable(classdict) else: @@ -946,9 +957,6 @@ def _get_mixins_(mcls, class_name, bases): """ if not bases: return object, Enum - - mcls._check_for_existing_members_(class_name, bases) - # ensure final parent class is an Enum derivative, find any concrete # data type, and check that Enum has no members first_enum = bases[-1] @@ -969,12 +977,20 @@ def _find_data_repr_(mcls, class_name, bases): return base._value_repr_ elif '__repr__' in base.__dict__: # this is our data repr - return base.__dict__['__repr__'] + # double-check if a dataclass with a default __repr__ + if ( + '__dataclass_fields__' in base.__dict__ + and '__dataclass_params__' in base.__dict__ + and base.__dict__['__dataclass_params__'].repr + ): + return _dataclass_repr + else: + return base.__dict__['__repr__'] return None @classmethod def _find_data_type_(mcls, class_name, bases): - # a datatype has a __new__ method + # a datatype has a __new__ method, or a __dataclass_fields__ attribute data_types = set() base_chain = set() for chain in bases: @@ -988,8 +1004,6 @@ def _find_data_type_(mcls, class_name, bases): data_types.add(base._member_type_) break elif '__new__' in base.__dict__ or '__dataclass_fields__' in base.__dict__: - if isinstance(base, EnumType): - continue data_types.add(candidate or base) break else: @@ -1061,20 +1075,20 @@ class Enum(metaclass=EnumType): Access them by: - - attribute access:: + - attribute access: - >>> Color.RED - + >>> Color.RED + - value lookup: - >>> Color(1) - + >>> Color(1) + - name lookup: - >>> Color['RED'] - + >>> Color['RED'] + Enumerations can be iterated over, and know how many members they have: @@ -1088,6 +1102,13 @@ class Enum(metaclass=EnumType): attributes -- see the documentation for details. """ + @classmethod + def __signature__(cls): + if cls._member_names_: + return '(*values)' + else: + return '(new_class_name, /, names, *, module=None, qualname=None, type=None, start=1, boundary=None)' + def __new__(cls, value): # all enum instances are actually created during class construction # without calling this method; this method is called by the metaclass' @@ -1107,6 +1128,11 @@ def __new__(cls, value): for member in cls._member_map_.values(): if member._value_ == value: return member + # still not found -- verify that members exist, in-case somebody got here mistakenly + # (such as via super when trying to override __new__) + if not cls._member_map_: + raise TypeError("%r has no members defined" % cls) + # # still not found -- try _missing_ hook try: exc = None @@ -1142,6 +1168,7 @@ def __new__(cls, value): def __init__(self, *args, **kwds): pass + @staticmethod def _generate_next_value_(name, start, count, last_values): """ Generate the next value when not given. @@ -1236,10 +1263,10 @@ def __copy__(self): # enum.property is used to provide access to the `name` and # `value` attributes of enum members while keeping some measure of # protection from modification, while still allowing for an enumeration - # to have members named `name` and `value`. This works because enumeration - # members are not set directly on the enum class; they are kept in a - # separate structure, _member_map_, which is where enum.property looks for - # them + # to have members named `name` and `value`. This works because each + # instance of enum.property saves its companion member, which it returns + # on class lookup; on instance lookup it either executes a provided function + # or raises an AttributeError. @property def name(self): @@ -1290,6 +1317,7 @@ def __new__(cls, *values): member._value_ = value return member + @staticmethod def _generate_next_value_(name, start, count, last_values): """ Return the lower-cased version of the member name. @@ -1328,6 +1356,7 @@ class Flag(Enum, boundary=STRICT): _numeric_repr_ = repr + @staticmethod def _generate_next_value_(name, start, count, last_values): """ Generate the next value when not given. @@ -1566,10 +1595,13 @@ def unique(enumeration): (enumeration, alias_details)) return enumeration -def _power_of_two(value): - if value < 1: - return False - return value == 2 ** _high_bit(value) +def _dataclass_repr(self): + dcf = self.__dataclass_fields__ + return ', '.join( + '%s=%r' % (k, getattr(self, k)) + for k in dcf.keys() + if dcf[k].repr + ) def global_enum_repr(self): """ @@ -1713,10 +1745,12 @@ def convert_class(cls): value = gnv(name, 1, len(member_names), gnv_last_values) if value in value2member_map: # an alias to an existing member + member = value2member_map[value] redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) - member_map[name] = value2member_map[value] + member_map[name] = member else: # create the member if use_args: @@ -1732,6 +1766,7 @@ def convert_class(cls): member.__objclass__ = enum_class member.__init__(value) redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) member_map[name] = member @@ -1760,10 +1795,12 @@ def convert_class(cls): value = value.value if value in value2member_map: # an alias to an existing member + member = value2member_map[value] redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) - member_map[name] = value2member_map[value] + member_map[name] = member else: # create the member if use_args: @@ -1780,6 +1817,7 @@ def convert_class(cls): member.__init__(value) member._sort_order_ = len(member_names) redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) member_map[name] = member @@ -1903,8 +1941,7 @@ def _test_simple_enum(checked_enum, simple_enum): ... RED = auto() ... GREEN = auto() ... BLUE = auto() - >>> # TODO: RUSTPYTHON - >>> # _test_simple_enum(CheckedColor, Color) + >>> _test_simple_enum(CheckedColor, Color) If differences are found, a :exc:`TypeError` is raised. """ diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 1c307e75ee..e09738ae27 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -20,7 +20,6 @@ from test import support from test.support import ALWAYS_EQ from test.support import threading_helper -from textwrap import dedent from datetime import timedelta python_version = sys.version_info[:2] @@ -237,11 +236,83 @@ class _EnumTests: values = None def setUp(self): - class BaseEnum(self.enum_type): + if self.__class__.__name__[-5:] == 'Class': + class BaseEnum(self.enum_type): + @enum.property + def first(self): + return '%s is first!' % self.name + class MainEnum(BaseEnum): + first = auto() + second = auto() + third = auto() + if issubclass(self.enum_type, Flag): + dupe = 3 + else: + dupe = third + self.MainEnum = MainEnum + # + class NewStrEnum(self.enum_type): + def __str__(self): + return self.name.upper() + first = auto() + self.NewStrEnum = NewStrEnum + # + class NewFormatEnum(self.enum_type): + def __format__(self, spec): + return self.name.upper() + first = auto() + self.NewFormatEnum = NewFormatEnum + # + class NewStrFormatEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + first = auto() + self.NewStrFormatEnum = NewStrFormatEnum + # + class NewBaseEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + self.NewBaseEnum = NewBaseEnum + class NewSubEnum(NewBaseEnum): + first = auto() + self.NewSubEnum = NewSubEnum + # + class LazyGNV(self.enum_type): + def _generate_next_value_(name, start, last, values): + pass + self.LazyGNV = LazyGNV + # + class BusyGNV(self.enum_type): + @staticmethod + def _generate_next_value_(name, start, last, values): + pass + self.BusyGNV = BusyGNV + # + self.is_flag = False + self.names = ['first', 'second', 'third'] + if issubclass(MainEnum, StrEnum): + self.values = self.names + elif MainEnum._member_type_ is str: + self.values = ['1', '2', '3'] + elif issubclass(self.enum_type, Flag): + self.values = [1, 2, 4] + self.is_flag = True + self.dupe2 = MainEnum(5) + else: + self.values = self.values or [1, 2, 3] + # + if not getattr(self, 'source_values', False): + self.source_values = self.values + elif self.__class__.__name__[-8:] == 'Function': @enum.property def first(self): return '%s is first!' % self.name - class MainEnum(BaseEnum): + BaseEnum = self.enum_type('BaseEnum', {'first':first}) + # first = auto() second = auto() third = auto() @@ -249,52 +320,58 @@ class MainEnum(BaseEnum): dupe = 3 else: dupe = third - self.MainEnum = MainEnum - # - class NewStrEnum(self.enum_type): + self.MainEnum = MainEnum = BaseEnum('MainEnum', dict(first=first, second=second, third=third, dupe=dupe)) + # def __str__(self): return self.name.upper() first = auto() - self.NewStrEnum = NewStrEnum - # - class NewFormatEnum(self.enum_type): + self.NewStrEnum = self.enum_type('NewStrEnum', (('first',first),('__str__',__str__))) + # def __format__(self, spec): return self.name.upper() first = auto() - self.NewFormatEnum = NewFormatEnum - # - class NewStrFormatEnum(self.enum_type): + self.NewFormatEnum = self.enum_type('NewFormatEnum', [('first',first),('__format__',__format__)]) + # def __str__(self): return self.name.title() def __format__(self, spec): return ''.join(reversed(self.name)) first = auto() - self.NewStrFormatEnum = NewStrFormatEnum - # - class NewBaseEnum(self.enum_type): + self.NewStrFormatEnum = self.enum_type('NewStrFormatEnum', dict(first=first, __format__=__format__, __str__=__str__)) + # def __str__(self): return self.name.title() def __format__(self, spec): return ''.join(reversed(self.name)) - class NewSubEnum(NewBaseEnum): - first = auto() - self.NewSubEnum = NewSubEnum - # - self.is_flag = False - self.names = ['first', 'second', 'third'] - if issubclass(MainEnum, StrEnum): - self.values = self.names - elif MainEnum._member_type_ is str: - self.values = ['1', '2', '3'] - elif issubclass(self.enum_type, Flag): - self.values = [1, 2, 4] - self.is_flag = True - self.dupe2 = MainEnum(5) + self.NewBaseEnum = self.enum_type('NewBaseEnum', dict(__format__=__format__, __str__=__str__)) + self.NewSubEnum = self.NewBaseEnum('NewSubEnum', 'first') + # + def _generate_next_value_(name, start, last, values): + pass + self.LazyGNV = self.enum_type('LazyGNV', {'_generate_next_value_':_generate_next_value_}) + # + @staticmethod + def _generate_next_value_(name, start, last, values): + pass + self.BusyGNV = self.enum_type('BusyGNV', {'_generate_next_value_':_generate_next_value_}) + # + self.is_flag = False + self.names = ['first', 'second', 'third'] + if issubclass(MainEnum, StrEnum): + self.values = self.names + elif MainEnum._member_type_ is str: + self.values = ['1', '2', '3'] + elif issubclass(self.enum_type, Flag): + self.values = [1, 2, 4] + self.is_flag = True + self.dupe2 = MainEnum(5) + else: + self.values = self.values or [1, 2, 3] + # + if not getattr(self, 'source_values', False): + self.source_values = self.values else: - self.values = self.values or [1, 2, 3] - # - if not getattr(self, 'source_values', False): - self.source_values = self.values + raise ValueError('unknown enum style: %r' % self.__class__.__name__) def assertFormatIsValue(self, spec, member): self.assertEqual(spec.format(member), spec.format(member.value)) @@ -322,6 +399,17 @@ def spam(cls): with self.assertRaises(AttributeError): del Season.SPRING.name + def test_bad_new_super(self): + with self.assertRaisesRegex( + TypeError, + 'has no members defined', + ): + class BadSuper(self.enum_type): + def __new__(cls, value): + obj = super().__new__(cls, value) + return obj + failed = 1 + def test_basics(self): TE = self.MainEnum if self.is_flag: @@ -373,19 +461,12 @@ def test_changing_member_fails(self): with self.assertRaises(AttributeError): self.MainEnum.second = 'really first' - @unittest.skipIf( - python_version >= (3, 12), - '__contains__ now returns True/False for all inputs', - ) - def test_contains_er(self): + def test_contains_tf(self): MainEnum = self.MainEnum - self.assertIn(MainEnum.third, MainEnum) - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - self.source_values[1] in MainEnum - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'first' in MainEnum + self.assertIn(MainEnum.first, MainEnum) + self.assertTrue(self.values[0] in MainEnum) + if type(self) not in (TestStrEnumClass, TestStrEnumFunction): + self.assertFalse('first' in MainEnum) val = MainEnum.dupe self.assertIn(val, MainEnum) # @@ -393,23 +474,43 @@ class OtherEnum(Enum): one = auto() two = auto() self.assertNotIn(OtherEnum.two, MainEnum) - - @unittest.skipIf( - python_version < (3, 12), - '__contains__ works only with enum memmbers before 3.12', - ) - def test_contains_tf(self): + # + if MainEnum._member_type_ is object: + # enums without mixed data types will always be False + class NotEqualEnum(self.enum_type): + this = self.source_values[0] + that = self.source_values[1] + self.assertNotIn(NotEqualEnum.this, MainEnum) + self.assertNotIn(NotEqualEnum.that, MainEnum) + else: + # enums with mixed data types may be True + class EqualEnum(self.enum_type): + this = self.source_values[0] + that = self.source_values[1] + self.assertIn(EqualEnum.this, MainEnum) + self.assertIn(EqualEnum.that, MainEnum) + + def test_contains_same_name_diff_enum_diff_values(self): MainEnum = self.MainEnum - self.assertIn(MainEnum.first, MainEnum) - self.assertTrue(self.source_values[0] in MainEnum) - self.assertFalse('first' in MainEnum) - val = MainEnum.dupe - self.assertIn(val, MainEnum) # class OtherEnum(Enum): - one = auto() - two = auto() - self.assertNotIn(OtherEnum.two, MainEnum) + first = "brand" + second = "new" + third = "values" + # + self.assertIn(MainEnum.first, MainEnum) + self.assertIn(MainEnum.second, MainEnum) + self.assertIn(MainEnum.third, MainEnum) + self.assertNotIn(MainEnum.first, OtherEnum) + self.assertNotIn(MainEnum.second, OtherEnum) + self.assertNotIn(MainEnum.third, OtherEnum) + # + self.assertIn(OtherEnum.first, OtherEnum) + self.assertIn(OtherEnum.second, OtherEnum) + self.assertIn(OtherEnum.third, OtherEnum) + self.assertNotIn(OtherEnum.first, MainEnum) + self.assertNotIn(OtherEnum.second, MainEnum) + self.assertNotIn(OtherEnum.third, MainEnum) def test_dir_on_class(self): TE = self.MainEnum @@ -459,10 +560,20 @@ class SubEnum(SuperEnum): self.assertTrue('description' not in dir(SubEnum)) self.assertTrue('description' in dir(SubEnum.sample), dir(SubEnum.sample)) + def test_empty_enum_has_no_values(self): + with self.assertRaisesRegex(TypeError, "<.... 'NewBaseEnum'> has no members"): + self.NewBaseEnum(7) + def test_enum_in_enum_out(self): Main = self.MainEnum self.assertIs(Main(Main.first), Main.first) + def test_gnv_is_static(self): + lazy = self.LazyGNV + busy = self.BusyGNV + self.assertTrue(type(lazy.__dict__['_generate_next_value_']) is staticmethod) + self.assertTrue(type(busy.__dict__['_generate_next_value_']) is staticmethod) + def test_hash(self): MainEnum = self.MainEnum mapping = {} @@ -499,7 +610,7 @@ def __repr__(self): def test_overridden_str(self): # TODO: RUSTPYTHON, format(NS.first) does not use __str__ - if isinstance(self, TestIntFlag) or isinstance(self, TestIntEnum) or isinstance(self, TestMinimalFloat): + if self.__class__ in (TestIntFlagFunction, TestIntFlagClass, TestIntEnumFunction, TestIntEnumClass, TestMinimalFloatFunction, TestMinimalFloatClass): self.skipTest("format(NS.first) does not use __str__") NS = self.NewStrEnum self.assertEqual(str(NS.first), NS.first.name.upper()) @@ -883,80 +994,192 @@ class OpenXYZ(self.enum_type): self.assertTrue(~OpenXYZ(0), (X|Y|Z)) -class TestPlainEnum(_EnumTests, _PlainOutputTests, unittest.TestCase): +class TestPlainEnumClass(_EnumTests, _PlainOutputTests, unittest.TestCase): enum_type = Enum -class TestPlainFlag(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): +class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase): + enum_type = Enum + + +class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): + enum_type = Flag + + +class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): enum_type = Flag -class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestIntEnumClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = IntEnum + # + def test_shadowed_attr(self): + class Number(IntEnum): + divisor = 1 + numerator = 2 + # + self.assertEqual(Number.divisor.numerator, 1) + self.assertIs(Number.numerator.divisor, Number.divisor) + + +class TestIntEnumFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): enum_type = IntEnum + # + def test_shadowed_attr(self): + Number = IntEnum('Number', ('divisor', 'numerator')) + # + self.assertEqual(Number.divisor.numerator, 1) + self.assertIs(Number.numerator.divisor, Number.divisor) + + +class TestStrEnumClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = StrEnum + # + def test_shadowed_attr(self): + class Book(StrEnum): + author = 'author' + title = 'title' + # + self.assertEqual(Book.author.title(), 'Author') + self.assertEqual(Book.title.title(), 'Title') + self.assertIs(Book.title.author, Book.author) -class TestStrEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestStrEnumFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): enum_type = StrEnum + # + def test_shadowed_attr(self): + Book = StrEnum('Book', ('author', 'title')) + # + self.assertEqual(Book.author.title(), 'Author') + self.assertEqual(Book.title.title(), 'Title') + self.assertIs(Book.title.author, Book.author) + + +class TestIntFlagClass(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): + enum_type = IntFlag -class TestIntFlag(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): +class TestIntFlagFunction(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): enum_type = IntFlag -class TestMixedInt(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestMixedIntClass(_EnumTests, _MixedOutputTests, unittest.TestCase): class enum_type(int, Enum): pass -class TestMixedStr(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestMixedIntFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + enum_type = Enum('enum_type', type=int) + + +class TestMixedStrClass(_EnumTests, _MixedOutputTests, unittest.TestCase): class enum_type(str, Enum): pass -class TestMixedIntFlag(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): +class TestMixedStrFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + enum_type = Enum('enum_type', type=str) + + +class TestMixedIntFlagClass(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): class enum_type(int, Flag): pass -class TestMixedDate(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestMixedIntFlagFunction(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): + enum_type = Flag('enum_type', type=int) + +class TestMixedDateClass(_EnumTests, _MixedOutputTests, unittest.TestCase): + # values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)] source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] - + # class enum_type(date, Enum): + @staticmethod def _generate_next_value_(name, start, count, last_values): values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] return values[count] -class TestMinimalDate(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestMixedDateFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + # + values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)] + source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + # + # staticmethod decorator will be added by EnumType if not present + def _generate_next_value_(name, start, count, last_values): + values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + return values[count] + # + enum_type = Enum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=date) + +class TestMinimalDateClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)] source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] - + # class enum_type(date, ReprEnum): + # staticmethod decorator will be added by EnumType if absent def _generate_next_value_(name, start, count, last_values): values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] return values[count] -class TestMixedFloat(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestMinimalDateFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # + values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)] + source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + # + @staticmethod + def _generate_next_value_(name, start, count, last_values): + values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + return values[count] + # + enum_type = ReprEnum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=date) - values = [1.1, 2.2, 3.3] +class TestMixedFloatClass(_EnumTests, _MixedOutputTests, unittest.TestCase): + # + values = [1.1, 2.2, 3.3] + # class enum_type(float, Enum): def _generate_next_value_(name, start, count, last_values): values = [1.1, 2.2, 3.3] return values[count] -class TestMinimalFloat(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestMixedFloatFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + # + values = [1.1, 2.2, 3.3] + # + def _generate_next_value_(name, start, count, last_values): + values = [1.1, 2.2, 3.3] + return values[count] + # + enum_type = Enum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=float) - values = [4.4, 5.5, 6.6] +class TestMinimalFloatClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # + values = [4.4, 5.5, 6.6] + # class enum_type(float, ReprEnum): def _generate_next_value_(name, start, count, last_values): values = [4.4, 5.5, 6.6] return values[count] +class TestMinimalFloatFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # + values = [4.4, 5.5, 6.6] + # + def _generate_next_value_(name, start, count, last_values): + values = [4.4, 5.5, 6.6] + return values[count] + # + enum_type = ReprEnum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=float) + + class TestSpecial(unittest.TestCase): """ various operations that are not attributable to every possible enum @@ -1244,6 +1467,28 @@ class Huh(Enum): self.assertEqual(Huh.name.name, 'name') self.assertEqual(Huh.name.value, 1) + def test_contains_name_and_value_overlap(self): + class IntEnum1(IntEnum): + X = 1 + class IntEnum2(IntEnum): + X = 1 + class IntEnum3(IntEnum): + X = 2 + class IntEnum4(IntEnum): + Y = 1 + self.assertIn(IntEnum1.X, IntEnum1) + self.assertIn(IntEnum1.X, IntEnum2) + self.assertNotIn(IntEnum1.X, IntEnum3) + self.assertIn(IntEnum1.X, IntEnum4) + + def test_contains_different_types_same_members(self): + class IntEnum1(IntEnum): + X = 1 + class IntFlag1(IntFlag): + X = 1 + self.assertIn(IntEnum1.X, IntFlag1) + self.assertIn(IntFlag1.X, IntEnum1) + def test_inherited_data_type(self): class HexInt(int): __qualname__ = 'HexInt' @@ -1262,7 +1507,6 @@ class MyEnum(HexInt, enum.Enum): # class SillyInt(HexInt): __qualname__ = 'SillyInt' - pass class MyOtherEnum(SillyInt, enum.Enum): __qualname__ = 'MyOtherEnum' D = 4 @@ -1396,6 +1640,21 @@ def test_programmatic_function_type_from_subclass_with_start(self): self.assertIn(e, MinorEnum) self.assertIs(type(e), MinorEnum) + def test_programmatic_function_is_value_call(self): + class TwoPart(Enum): + ONE = 1, 1.0 + TWO = 2, 2.0 + THREE = 3, 3.0 + self.assertRaisesRegex(ValueError, '1 is not a valid .*TwoPart', TwoPart, 1) + self.assertIs(TwoPart((1, 1.0)), TwoPart.ONE) + self.assertIs(TwoPart(1, 1.0), TwoPart.ONE) + class ThreePart(Enum): + ONE = 1, 1.0, 'one' + TWO = 2, 2.0, 'two' + THREE = 3, 3.0, 'three' + self.assertIs(ThreePart((3, 3.0, 'three')), ThreePart.THREE) + self.assertIs(ThreePart(3, 3.0, 'three'), ThreePart.THREE) + # TODO: RUSTPYTHON, AssertionError: is not @unittest.expectedFailure def test_intenum_from_bytes(self): @@ -1539,7 +1798,7 @@ class MoreColor(Color): class EvenMoreColor(Color, IntEnum): chartruese = 7 # - with self.assertRaisesRegex(TypeError, " cannot extend "): + with self.assertRaisesRegex(ValueError, r"\(.Foo., \(.pink., .black.\)\) is not a valid .*Color"): Color('Foo', ('pink', 'black')) def test_exclude_methods(self): @@ -2733,14 +2992,15 @@ class Private(Enum): self.assertEqual(Private._Private__corporal, 'Radar') self.assertEqual(Private._Private__major_, 'Hoolihan') - @unittest.skip("Accessing all values retained for performance reasons, see GH-93910") - def test_exception_for_member_from_member_access(self): - with self.assertRaisesRegex(AttributeError, " member has no attribute .NO."): - class Di(Enum): - YES = 1 - NO = 0 - nope = Di.YES.NO - + def test_member_from_member_access(self): + class Di(Enum): + YES = 1 + NO = 0 + name = 3 + warn = Di.YES.NO + self.assertIs(warn, Di.NO) + self.assertIs(Di.name, Di['name']) + self.assertEqual(Di.name.name, 'name') def test_dynamic_members_with_static_methods(self): # @@ -2771,20 +3031,69 @@ def upper(self): def test_repr_with_dataclass(self): "ensure dataclass-mixin has correct repr()" - from dataclasses import dataclass - @dataclass + # + # check overridden dataclass __repr__ is used + # + from dataclasses import dataclass, field + @dataclass(repr=False) class Foo: __qualname__ = 'Foo' a: int + def __repr__(self): + return 'ha hah!' class Entries(Foo, Enum): ENTRY1 = 1 + self.assertEqual(repr(Entries.ENTRY1), '') + self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value) self.assertTrue(isinstance(Entries.ENTRY1, Foo)) self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_) - self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value) - self.assertEqual(repr(Entries.ENTRY1), '') - - def test_repr_with_init_data_type_mixin(self): - # non-data_type is a mixin that doesn't define __new__ + # + # check auto-generated dataclass __repr__ is not used + # + @dataclass + class CreatureDataMixin: + __qualname__ = 'CreatureDataMixin' + size: str + legs: int + tail: bool = field(repr=False, default=True) + class Creature(CreatureDataMixin, Enum): + __qualname__ = 'Creature' + BEETLE = ('small', 6) + DOG = ('medium', 4) + self.assertEqual(repr(Creature.DOG), "") + # + # check inherited repr used + # + class Huh: + def __repr__(self): + return 'inherited' + @dataclass(repr=False) + class CreatureDataMixin(Huh): + __qualname__ = 'CreatureDataMixin' + size: str + legs: int + tail: bool = field(repr=False, default=True) + class Creature(CreatureDataMixin, Enum): + __qualname__ = 'Creature' + BEETLE = ('small', 6) + DOG = ('medium', 4) + self.assertEqual(repr(Creature.DOG), "") + # + # check default object.__repr__ used if nothing provided + # + @dataclass(repr=False) + class CreatureDataMixin: + __qualname__ = 'CreatureDataMixin' + size: str + legs: int + tail: bool = field(repr=False, default=True) + class Creature(CreatureDataMixin, Enum): + __qualname__ = 'Creature' + BEETLE = ('small', 6) + DOG = ('medium', 4) + self.assertRegex(repr(Creature.DOG), "") + + def test_repr_with_init_mixin(self): class Foo: def __init__(self, a): self.a = a @@ -2795,7 +3104,7 @@ class Entries(Foo, Enum): # self.assertEqual(repr(Entries.ENTRY1), 'Foo(a=1)') - def test_repr_and_str_with_non_data_type_mixin(self): + def test_repr_and_str_with_no_init_mixin(self): # non-data_type is a mixin that doesn't define __new__ class Foo: def __repr__(self): @@ -3250,32 +3559,6 @@ def test_pickle(self): test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.BIG) - @unittest.skipIf( - python_version >= (3, 12), - '__contains__ now returns True/False for all inputs', - ) - def test_contains_er(self): - Open = self.Open - Color = self.Color - self.assertFalse(Color.BLACK in Open) - self.assertFalse(Open.RO in Color) - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'BLACK' in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'RO' in Open - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 1 in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 1 in Open - - @unittest.skipIf( - python_version < (3, 12), - '__contains__ only works with enum memmbers before 3.12', - ) def test_contains_tf(self): Open = self.Open Color = self.Color @@ -3283,6 +3566,8 @@ def test_contains_tf(self): self.assertFalse(Open.RO in Color) self.assertFalse('BLACK' in Color) self.assertFalse('RO' in Open) + self.assertTrue(Color.BLACK in Color) + self.assertTrue(Open.RO in Open) self.assertTrue(1 in Color) self.assertTrue(1 in Open) @@ -3449,9 +3734,8 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.wait_threads_exit(): - with threading_helper.start_threads(threads): - pass + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, @@ -3827,41 +4111,11 @@ def test_programatic_function_from_empty_tuple(self): self.assertEqual(len(lst), len(Thing)) self.assertEqual(len(Thing), 0, Thing) - @unittest.skipIf( - python_version >= (3, 12), - '__contains__ now returns True/False for all inputs', - ) - def test_contains_er(self): - Open = self.Open - Color = self.Color - self.assertTrue(Color.GREEN in Color) - self.assertTrue(Open.RW in Open) - self.assertFalse(Color.GREEN in Open) - self.assertFalse(Open.RW in Color) - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'GREEN' in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'RW' in Open - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 2 in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 2 in Open - - @unittest.skipIf( - python_version < (3, 12), - '__contains__ only works with enum memmbers before 3.12', - ) def test_contains_tf(self): Open = self.Open Color = self.Color self.assertTrue(Color.GREEN in Color) self.assertTrue(Open.RW in Open) - self.assertTrue(Color.GREEN in Open) - self.assertTrue(Open.RW in Color) self.assertFalse('GREEN' in Color) self.assertFalse('RW' in Open) self.assertTrue(2 in Color) @@ -3967,6 +4221,7 @@ class Color(StrMixin, AllMixin, IntFlag): self.assertEqual(Color.ALL.value, 7) self.assertEqual(str(Color.BLUE), 'blue') + @unittest.skip("TODO: RUSTPYTHON; flaky test") @threading_helper.reap_threads @threading_helper.requires_working_threading() def test_unique_composite(self): @@ -3998,9 +4253,8 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.wait_threads_exit(): - with threading_helper.start_threads(threads): - pass + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, @@ -4651,6 +4905,29 @@ def test_inspect_classify_class_attrs(self): if failed: self.fail("result does not equal expected, see print above") + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_inspect_signatures(self): + from inspect import signature, Signature, Parameter + self.assertEqual( + signature(Enum), + Signature([ + Parameter('new_class_name', Parameter.POSITIONAL_ONLY), + Parameter('names', Parameter.POSITIONAL_OR_KEYWORD), + Parameter('module', Parameter.KEYWORD_ONLY, default=None), + Parameter('qualname', Parameter.KEYWORD_ONLY, default=None), + Parameter('type', Parameter.KEYWORD_ONLY, default=None), + Parameter('start', Parameter.KEYWORD_ONLY, default=1), + Parameter('boundary', Parameter.KEYWORD_ONLY, default=None), + ]), + ) + self.assertEqual( + signature(enum.FlagBoundary), + Signature([ + Parameter('values', Parameter.VAR_POSITIONAL), + ]), + ) + # TODO: RUSTPYTHON, len is often/always > 256 @unittest.expectedFailure def test_test_simple_enum(self): @@ -4756,11 +5033,6 @@ class Quadruple(Enum): COMPLEX_A = 2j COMPLEX_B = 3j -class _ModuleWrapper: - """We use this class as a namespace for swapping modules.""" - def __init__(self, module): - self.__dict__.update(module.__dict__) - class TestConvert(unittest.TestCase): def tearDown(self): # Reset the module-level test variables to their original integer @@ -4800,12 +5072,6 @@ def test_convert_int(self): self.assertEqual(test_type.CONVERT_TEST_NAME_D, 5) self.assertEqual(test_type.CONVERT_TEST_NAME_E, 5) # Ensure that test_type only picked up names matching the filter. - int_dir = dir(int) + [ - 'CONVERT_TEST_NAME_A', 'CONVERT_TEST_NAME_B', 'CONVERT_TEST_NAME_C', - 'CONVERT_TEST_NAME_D', 'CONVERT_TEST_NAME_E', 'CONVERT_TEST_NAME_F', - 'CONVERT_TEST_SIGABRT', 'CONVERT_TEST_SIGIOT', - 'CONVERT_TEST_EIO', 'CONVERT_TEST_EBUS', - ] extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] self.assertEqual( @@ -4847,7 +5113,6 @@ def test_convert_str(self): self.assertEqual(test_type.CONVERT_STR_TEST_1, 'hello') self.assertEqual(test_type.CONVERT_STR_TEST_2, 'goodbye') # Ensure that test_type only picked up names matching the filter. - str_dir = dir(str) + ['CONVERT_STR_TEST_1', 'CONVERT_STR_TEST_2'] extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] self.assertEqual( @@ -4915,8 +5180,6 @@ def member_dir(member): allowed.add(name) return sorted(allowed) -missing = object() - if __name__ == '__main__': unittest.main() From 6e721336f7a273999e2c3bc173f2043d80700a9a Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 19:57:51 -0700 Subject: [PATCH 17/19] enum --- Lib/enum.py | 3 ++- Lib/test/test_enum.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Lib/enum.py b/Lib/enum.py index c207dc234c..7cffb71863 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1941,7 +1941,8 @@ def _test_simple_enum(checked_enum, simple_enum): ... RED = auto() ... GREEN = auto() ... BLUE = auto() - >>> _test_simple_enum(CheckedColor, Color) + ... # TODO: RUSTPYTHON + >>> _test_simple_enum(CheckedColor, Color) # doctest: +SKIP If differences are found, a :exc:`TypeError` is raised. """ diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index e09738ae27..b36c368f1f 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -4794,6 +4794,8 @@ class Color(Enum): MAGENTA = 2 YELLOW = 3 + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_pydoc(self): # indirectly test __objclass__ if StrEnum.__doc__ is None: From b9e5ec00b617acd0fa9d5ad3c2d01cee460953de Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 18:48:09 -0700 Subject: [PATCH 18/19] Doc format changed --- Lib/test/test_enum.py | 109 +++++++++++++++--------------------------- 1 file changed, 39 insertions(+), 70 deletions(-) diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index b36c368f1f..3989b7d674 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -4671,118 +4671,87 @@ class TestEnumTypeSubclassing(unittest.TestCase): Help on class Color in module %s: class Color(enum.Enum) - | Create a collection of name/value pairs. - |\x20\x20 - | Example enumeration: - |\x20\x20 - | >>> class Color(Enum): - | ... RED = 1 - | ... BLUE = 2 - | ... GREEN = 3 - |\x20\x20 - | Access them by: - |\x20\x20 - | - attribute access:: - |\x20\x20 - | >>> Color.RED - | - |\x20\x20 - | - value lookup: - |\x20\x20 - | >>> Color(1) - | - |\x20\x20 - | - name lookup: - |\x20\x20 - | >>> Color['RED'] - | - |\x20\x20 - | Enumerations can be iterated over, and know how many members they have: - |\x20\x20 - | >>> len(Color) - | 3 - |\x20\x20 - | >>> list(Color) - | [, , ] - |\x20\x20 - | Methods can be added to enumerations, and members can have their own - | attributes -- see the documentation for details. - |\x20\x20 + | Color(*values) + | | Method resolution order: | Color | enum.Enum | builtins.object - |\x20\x20 + | | Data and other attributes defined here: - |\x20\x20 + | | CYAN = - |\x20\x20 + | | MAGENTA = - |\x20\x20 + | | YELLOW = - |\x20\x20 + | | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: - |\x20\x20 + | | name | The name of the Enum member. - |\x20\x20 + | | value | The value of the Enum member. - |\x20\x20 + | | ---------------------------------------------------------------------- | Methods inherited from enum.EnumType: - |\x20\x20 - | __contains__(member) from enum.EnumType - | Return True if member is a member of this enum - | raises TypeError if member is not an enum member - |\x20\x20\x20\x20\x20\x20 - | note: in 3.12 TypeError will no longer be raised, and True will also be - | returned if member is the value of a member in this enum - |\x20\x20 + | + | __contains__(value) from enum.EnumType + | Return True if `value` is in `cls`. + | + | `value` is in `cls` if: + | 1) `value` is a member of `cls`, or + | 2) `value` is the value of one of the `cls`'s members. + | | __getitem__(name) from enum.EnumType | Return the member matching `name`. - |\x20\x20 + | | __iter__() from enum.EnumType | Return members in definition order. - |\x20\x20 + | | __len__() from enum.EnumType | Return the number of members (no aliases) - |\x20\x20 + | | ---------------------------------------------------------------------- - | Data descriptors inherited from enum.EnumType: - |\x20\x20 - | __members__""" + | Readonly properties inherited from enum.EnumType: + | + | __members__ + | Returns a mapping of member name->value. + | + | This mapping lists all enum members, including aliases. Note that this + | is a read-only view of the internal mapping.""" expected_help_output_without_docs = """\ Help on class Color in module %s: class Color(enum.Enum) - | Color(value, names=None, *, module=None, qualname=None, type=None, start=1) - |\x20\x20 + | Color(*values) + | | Method resolution order: | Color | enum.Enum | builtins.object - |\x20\x20 + | | Data and other attributes defined here: - |\x20\x20 + | | YELLOW = - |\x20\x20 + | | MAGENTA = - |\x20\x20 + | | CYAN = - |\x20\x20 + | | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: - |\x20\x20 + | | name - |\x20\x20 + | | value - |\x20\x20 + | | ---------------------------------------------------------------------- | Data descriptors inherited from enum.EnumType: - |\x20\x20 + | | __members__""" class TestStdLib(unittest.TestCase): From 6e699d93fe9ed1f16921ef3ab4a83fc0a72c5473 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 20 Oct 2023 18:52:39 -0700 Subject: [PATCH 19/19] Update test_module from CPython --- .../__init__.py} | 55 ++++++++----------- Lib/test/test_module/bad_getattr.py | 4 ++ Lib/test/test_module/bad_getattr2.py | 7 +++ Lib/test/test_module/bad_getattr3.py | 5 ++ Lib/test/test_module/good_getattr.py | 11 ++++ 5 files changed, 50 insertions(+), 32 deletions(-) rename Lib/test/{test_module.py => test_module/__init__.py} (90%) create mode 100644 Lib/test/test_module/bad_getattr.py create mode 100644 Lib/test/test_module/bad_getattr2.py create mode 100644 Lib/test/test_module/bad_getattr3.py create mode 100644 Lib/test/test_module/good_getattr.py diff --git a/Lib/test/test_module.py b/Lib/test/test_module/__init__.py similarity index 90% rename from Lib/test/test_module.py rename to Lib/test/test_module/__init__.py index b921fc6f4e..d8a0ba0803 100644 --- a/Lib/test/test_module.py +++ b/Lib/test/test_module/__init__.py @@ -8,17 +8,16 @@ import sys ModuleType = type(sys) + class FullLoader: - @classmethod - def module_repr(cls, m): - return "".format(m.__name__) + pass + class BareLoader: pass class ModuleTests(unittest.TestCase): - def test_uninitialized(self): # An uninitialized module has no __dict__ or __name__, # and __doc__ is None @@ -128,11 +127,9 @@ def test_weakref(self): gc_collect() self.assertIs(wr(), None) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_module_getattr(self): - import test.good_getattr as gga - from test.good_getattr import test + import test.test_module.good_getattr as gga + from test.test_module.good_getattr import test self.assertEqual(test, "There is test") self.assertEqual(gga.x, 1) self.assertEqual(gga.y, 2) @@ -140,54 +137,50 @@ def test_module_getattr(self): "Deprecated, use whatever instead"): gga.yolo self.assertEqual(gga.whatever, "There is whatever") - del sys.modules['test.good_getattr'] + del sys.modules['test.test_module.good_getattr'] - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_module_getattr_errors(self): - import test.bad_getattr as bga - from test import bad_getattr2 + import test.test_module.bad_getattr as bga + from test.test_module import bad_getattr2 self.assertEqual(bga.x, 1) self.assertEqual(bad_getattr2.x, 1) with self.assertRaises(TypeError): bga.nope with self.assertRaises(TypeError): bad_getattr2.nope - del sys.modules['test.bad_getattr'] - if 'test.bad_getattr2' in sys.modules: - del sys.modules['test.bad_getattr2'] + del sys.modules['test.test_module.bad_getattr'] + if 'test.test_module.bad_getattr2' in sys.modules: + del sys.modules['test.test_module.bad_getattr2'] # TODO: RUSTPYTHON @unittest.expectedFailure def test_module_dir(self): - import test.good_getattr as gga + import test.test_module.good_getattr as gga self.assertEqual(dir(gga), ['a', 'b', 'c']) - del sys.modules['test.good_getattr'] + del sys.modules['test.test_module.good_getattr'] # TODO: RUSTPYTHON @unittest.expectedFailure def test_module_dir_errors(self): - import test.bad_getattr as bga - from test import bad_getattr2 + import test.test_module.bad_getattr as bga + from test.test_module import bad_getattr2 with self.assertRaises(TypeError): dir(bga) with self.assertRaises(TypeError): dir(bad_getattr2) - del sys.modules['test.bad_getattr'] - if 'test.bad_getattr2' in sys.modules: - del sys.modules['test.bad_getattr2'] + del sys.modules['test.test_module.bad_getattr'] + if 'test.test_module.bad_getattr2' in sys.modules: + del sys.modules['test.test_module.bad_getattr2'] - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_module_getattr_tricky(self): - from test import bad_getattr3 + from test.test_module import bad_getattr3 # these lookups should not crash with self.assertRaises(AttributeError): bad_getattr3.one with self.assertRaises(AttributeError): bad_getattr3.delgetattr - if 'test.bad_getattr3' in sys.modules: - del sys.modules['test.bad_getattr3'] + if 'test.test_module.bad_getattr3' in sys.modules: + del sys.modules['test.test_module.bad_getattr3'] def test_module_repr_minimal(self): # reprs when modules have no __file__, __name__, or __loader__ @@ -249,10 +242,9 @@ def test_module_repr_with_full_loader(self): # Yes, a class not an instance. m.__loader__ = FullLoader self.assertEqual( - repr(m), "") + repr(m), f")>") def test_module_repr_with_bare_loader_and_filename(self): - # Because the loader has no module_repr(), use the file name. m = ModuleType('foo') # Yes, a class not an instance. m.__loader__ = BareLoader @@ -260,12 +252,11 @@ def test_module_repr_with_bare_loader_and_filename(self): self.assertEqual(repr(m), "") def test_module_repr_with_full_loader_and_filename(self): - # Even though the module has an __file__, use __loader__.module_repr() m = ModuleType('foo') # Yes, a class not an instance. m.__loader__ = FullLoader m.__file__ = '/tmp/foo.py' - self.assertEqual(repr(m), "") + self.assertEqual(repr(m), "") def test_module_repr_builtin(self): self.assertEqual(repr(sys), "") diff --git a/Lib/test/test_module/bad_getattr.py b/Lib/test/test_module/bad_getattr.py new file mode 100644 index 0000000000..16f901b13b --- /dev/null +++ b/Lib/test/test_module/bad_getattr.py @@ -0,0 +1,4 @@ +x = 1 + +__getattr__ = "Surprise!" +__dir__ = "Surprise again!" diff --git a/Lib/test/test_module/bad_getattr2.py b/Lib/test/test_module/bad_getattr2.py new file mode 100644 index 0000000000..0a52a53b54 --- /dev/null +++ b/Lib/test/test_module/bad_getattr2.py @@ -0,0 +1,7 @@ +def __getattr__(): + "Bad one" + +x = 1 + +def __dir__(bad_sig): + return [] diff --git a/Lib/test/test_module/bad_getattr3.py b/Lib/test/test_module/bad_getattr3.py new file mode 100644 index 0000000000..0d5f9266c7 --- /dev/null +++ b/Lib/test/test_module/bad_getattr3.py @@ -0,0 +1,5 @@ +def __getattr__(name): + if name != 'delgetattr': + raise AttributeError + del globals()['__getattr__'] + raise AttributeError diff --git a/Lib/test/test_module/good_getattr.py b/Lib/test/test_module/good_getattr.py new file mode 100644 index 0000000000..7d27de6262 --- /dev/null +++ b/Lib/test/test_module/good_getattr.py @@ -0,0 +1,11 @@ +x = 1 + +def __dir__(): + return ['a', 'b', 'c'] + +def __getattr__(name): + if name == "yolo": + raise AttributeError("Deprecated, use whatever instead") + return f"There is {name}" + +y = 2