diff --git a/Include/cpython/pystate.h b/Include/cpython/pystate.h index 54d7e62292966e..59c425339100a2 100644 --- a/Include/cpython/pystate.h +++ b/Include/cpython/pystate.h @@ -208,6 +208,20 @@ struct _ts { */ PyObject *threading_local_sentinel; _PyRemoteDebuggerSupport remote_debugger_support; + + struct { + /* Number of nested PyThreadState_Ensure() calls on this thread state */ + Py_ssize_t counter; + + /* Thread state that was active before PyThreadState_Ensure() was called. */ + PyThreadState *prior_tstate; + + /* Should this thread state be deleted upon calling + PyThreadState_Release() (with the counter at 1)? + + This is only true for thread states created by PyThreadState_Ensure() */ + int delete_on_release; + } ensure; }; /* other API */ @@ -261,3 +275,42 @@ PyAPI_FUNC(_PyFrameEvalFunction) _PyInterpreterState_GetEvalFrameFunc( PyAPI_FUNC(void) _PyInterpreterState_SetEvalFrameFunc( PyInterpreterState *interp, _PyFrameEvalFunction eval_frame); + +/* Strong interpreter references */ + +typedef uintptr_t PyInterpreterRef; + +PyAPI_FUNC(int) PyInterpreterRef_Get(PyInterpreterRef *ref); +PyAPI_FUNC(PyInterpreterRef) PyInterpreterRef_Dup(PyInterpreterRef ref); +PyAPI_FUNC(int) PyInterpreterRef_Main(PyInterpreterRef *ref); +PyAPI_FUNC(void) PyInterpreterRef_Close(PyInterpreterRef ref); +PyAPI_FUNC(PyInterpreterState *) PyInterpreterRef_AsInterpreter(PyInterpreterRef ref); + +#define PyInterpreterRef_Close(ref) do { \ + PyInterpreterRef_Close(ref); \ + ref = 0; \ +} while (0) + +/* Weak interpreter references */ + +typedef struct _PyInterpreterWeakRef { + int64_t id; + Py_ssize_t refcount; +} _PyInterpreterWeakRef; + +typedef _PyInterpreterWeakRef *PyInterpreterWeakRef; + +PyAPI_FUNC(int) PyInterpreterWeakRef_Get(PyInterpreterWeakRef *ptr); +PyAPI_FUNC(PyInterpreterWeakRef) PyInterpreterWeakRef_Dup(PyInterpreterWeakRef wref); +PyAPI_FUNC(int) PyInterpreterWeakRef_AsStrong(PyInterpreterWeakRef wref, PyInterpreterRef *strong_ptr); +PyAPI_FUNC(void) PyInterpreterWeakRef_Close(PyInterpreterWeakRef wref); + +#define PyInterpreterWeakRef_Close(ref) do { \ + PyInterpreterWeakRef_Close(ref); \ + ref = 0; \ +} while (0) + + +PyAPI_FUNC(int) PyThreadState_Ensure(PyInterpreterRef interp_ref); + +PyAPI_FUNC(void) PyThreadState_Release(void); diff --git a/Include/internal/pycore_interp_structs.h b/Include/internal/pycore_interp_structs.h index f1f427d99dea69..feed971e921e20 100644 --- a/Include/internal/pycore_interp_structs.h +++ b/Include/internal/pycore_interp_structs.h @@ -810,6 +810,13 @@ struct _is { or the size specified by the THREAD_STACK_SIZE macro. */ /* Used in Python/thread.c. */ size_t stacksize; + + struct _Py_finalizing_threads { + Py_ssize_t countdown; + PyEvent finished; + PyMutex mutex; + int shutting_down; + } finalizing; } threads; /* Reference to the _PyRuntime global variable. This field exists diff --git a/Include/internal/pycore_pystate.h b/Include/internal/pycore_pystate.h index 633e5cf77db918..9f6ec120317592 100644 --- a/Include/internal/pycore_pystate.h +++ b/Include/internal/pycore_pystate.h @@ -328,6 +328,10 @@ _Py_RecursionLimit_GetMargin(PyThreadState *tstate) return Py_ARITHMETIC_RIGHT_SHIFT(intptr_t, here_addr - (intptr_t)_tstate->c_stack_soft_limit, PYOS_STACK_MARGIN_SHIFT); } +// Exports for '_testinternalcapi' shared extension +PyAPI_FUNC(Py_ssize_t) _PyInterpreterState_Refcount(PyInterpreterState *interp); +PyAPI_FUNC(int) _PyInterpreterState_Incref(PyInterpreterState *interp); + #ifdef __cplusplus } #endif diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py index 89f4aebe28f4a1..124213b9ca36e6 100644 --- a/Lib/test/test_embed.py +++ b/Lib/test/test_embed.py @@ -1914,10 +1914,15 @@ def test_audit_run_stdin(self): def test_get_incomplete_frame(self): self.run_embedded_interpreter("test_get_incomplete_frame") - def test_gilstate_after_finalization(self): self.run_embedded_interpreter("test_gilstate_after_finalization") + def test_thread_state_ensure(self): + self.run_embedded_interpreter("test_thread_state_ensure") + + def test_main_interpreter_ref(self): + self.run_embedded_interpreter("test_main_interpreter_ref") + class MiscTests(EmbeddingTestsMixin, unittest.TestCase): def test_unicode_id_init(self): diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c index 71fffedee146fa..f85667f91f01ff 100644 --- a/Modules/_testcapimodule.c +++ b/Modules/_testcapimodule.c @@ -2546,6 +2546,194 @@ toggle_reftrace_printer(PyObject *ob, PyObject *arg) Py_RETURN_NONE; } +static PyInterpreterRef +get_strong_ref(void) +{ + PyInterpreterRef ref; + if (PyInterpreterRef_Get(&ref) < 0) { + Py_FatalError("strong reference should not have failed"); + } + return ref; +} + +static void +test_interp_ref_common(void) +{ + PyInterpreterState *interp = PyInterpreterState_Get(); + PyInterpreterRef ref = get_strong_ref(); + assert(PyInterpreterRef_AsInterpreter(ref) == interp); + + PyInterpreterRef ref_2 = PyInterpreterRef_Dup(ref); + assert(PyInterpreterRef_AsInterpreter(ref_2) == interp); + + // We can close the references in any order + PyInterpreterRef_Close(ref); + PyInterpreterRef_Close(ref_2); +} + +static PyObject * +test_interpreter_refs(PyObject *self, PyObject *unused) +{ + // Test the main interpreter + test_interp_ref_common(); + + // Test a (legacy) subinterpreter + PyThreadState *save_tstate = PyThreadState_Swap(NULL); + PyThreadState *interp_tstate = Py_NewInterpreter(); + test_interp_ref_common(); + Py_EndInterpreter(interp_tstate); + + // Test an isolated subinterpreter + PyInterpreterConfig config = { + .gil = PyInterpreterConfig_OWN_GIL, + .check_multi_interp_extensions = 1 + }; + + PyThreadState *isolated_interp_tstate; + PyStatus status = Py_NewInterpreterFromConfig(&isolated_interp_tstate, &config); + if (PyStatus_Exception(status)) { + PyErr_SetString(PyExc_RuntimeError, "interpreter creation failed"); + return NULL; + } + + test_interp_ref_common(); + Py_EndInterpreter(isolated_interp_tstate); + PyThreadState_Swap(save_tstate); + Py_RETURN_NONE; +} + +static PyObject * +test_thread_state_ensure_nested(PyObject *self, PyObject *unused) +{ + PyInterpreterRef ref = get_strong_ref(); + PyThreadState *save_tstate = PyThreadState_Swap(NULL); + assert(PyGILState_GetThisThreadState() == save_tstate); + + for (int i = 0; i < 10; ++i) { + // Test reactivation of the detached tstate. + if (PyThreadState_Ensure(ref) < 0) { + PyInterpreterRef_Close(ref); + return PyErr_NoMemory(); + } + + // No new thread state should've been created. + assert(PyThreadState_Get() == save_tstate); + PyThreadState_Release(); + } + + assert(PyThreadState_GetUnchecked() == NULL); + + // Similarly, test ensuring with deep nesting and *then* releasing. + // If the (detached) gilstate matches the interpreter, then it shouldn't + // create a new thread state. + for (int i = 0; i < 10; ++i) { + if (PyThreadState_Ensure(ref) < 0) { + // This will technically leak other thread states, but it doesn't + // matter because this is a test. + PyInterpreterRef_Close(ref); + return PyErr_NoMemory(); + } + + assert(PyThreadState_Get() == save_tstate); + } + + for (int i = 0; i < 10; ++i) { + assert(PyThreadState_Get() == save_tstate); + PyThreadState_Release(); + } + + assert(PyThreadState_GetUnchecked() == NULL); + PyInterpreterRef_Close(ref); + PyThreadState_Swap(save_tstate); + Py_RETURN_NONE; +} + +static PyObject * +test_thread_state_ensure_crossinterp(PyObject *self, PyObject *unused) +{ + PyInterpreterRef ref = get_strong_ref(); + PyThreadState *save_tstate = PyThreadState_Swap(NULL); + PyThreadState *interp_tstate = Py_NewInterpreter(); + if (interp_tstate == NULL) { + PyInterpreterRef_Close(ref); + return PyErr_NoMemory(); + } + + /* This should create a new thread state for the calling interpreter, *not* + reactivate the old one. In a real-world scenario, this would arise in + something like this: + + def some_func(): + import something + # This re-enters the main interpreter, but we + # shouldn't have access to prior thread-locals. + something.call_something() + + interp = interpreters.create() + interp.exec(some_func) + */ + if (PyThreadState_Ensure(ref) < 0) { + PyInterpreterRef_Close(ref); + return PyErr_NoMemory(); + } + + PyThreadState *ensured_tstate = PyThreadState_Get(); + assert(ensured_tstate != save_tstate); + assert(PyInterpreterState_Get() == PyInterpreterRef_AsInterpreter(ref)); + assert(PyGILState_GetThisThreadState() == ensured_tstate); + + // Now though, we should reactivate the thread state + if (PyThreadState_Ensure(ref) < 0) { + PyInterpreterRef_Close(ref); + return PyErr_NoMemory(); + } + + assert(PyThreadState_Get() == ensured_tstate); + PyThreadState_Release(); + + // Ensure that we're restoring the prior thread state + PyThreadState_Release(); + assert(PyThreadState_Get() == interp_tstate); + assert(PyGILState_GetThisThreadState() == interp_tstate); + + PyThreadState_Swap(interp_tstate); + Py_EndInterpreter(interp_tstate); + + PyInterpreterRef_Close(ref); + PyThreadState_Swap(save_tstate); + Py_RETURN_NONE; +} + +static PyObject * +test_weak_interpreter_ref_after_shutdown(PyObject *self, PyObject *unused) +{ + PyThreadState *save_tstate = PyThreadState_Swap(NULL); + PyInterpreterWeakRef wref; + PyThreadState *interp_tstate = Py_NewInterpreter(); + if (interp_tstate == NULL) { + return PyErr_NoMemory(); + } + + int res = PyInterpreterWeakRef_Get(&wref); + (void)res; + assert(res == 0); + + // As a sanity check, ensure that the weakref actually works + PyInterpreterRef ref; + res = PyInterpreterWeakRef_AsStrong(wref, &ref); + assert(res == 0); + PyInterpreterRef_Close(ref); + + // Now, destroy the interpreter and try to acquire a weak reference. + // It should fail. + Py_EndInterpreter(interp_tstate); + res = PyInterpreterWeakRef_AsStrong(wref, &ref); + assert(res == -1); + + PyThreadState_Swap(save_tstate); + Py_RETURN_NONE; +} + static PyMethodDef TestMethods[] = { {"set_errno", set_errno, METH_VARARGS}, {"test_config", test_config, METH_NOARGS}, @@ -2640,6 +2828,10 @@ static PyMethodDef TestMethods[] = { {"test_atexit", test_atexit, METH_NOARGS}, {"code_offset_to_line", _PyCFunction_CAST(code_offset_to_line), METH_FASTCALL}, {"toggle_reftrace_printer", toggle_reftrace_printer, METH_O}, + {"test_interpreter_refs", test_interpreter_refs, METH_NOARGS}, + {"test_thread_state_ensure_nested", test_thread_state_ensure_nested, METH_NOARGS}, + {"test_thread_state_ensure_crossinterp", test_thread_state_ensure_crossinterp, METH_NOARGS}, + {"test_weak_interpreter_ref_after_shutdown", test_weak_interpreter_ref_after_shutdown, METH_NOARGS}, {NULL, NULL} /* sentinel */ }; diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c index 804cb4e4d1c8ee..de0c179ae30b07 100644 --- a/Modules/_testinternalcapi.c +++ b/Modules/_testinternalcapi.c @@ -2345,6 +2345,60 @@ incref_decref_delayed(PyObject *self, PyObject *op) Py_RETURN_NONE; } +#define NUM_REFS 100 + +static PyObject * +test_interp_refcount(PyObject *self, PyObject *unused) +{ + PyInterpreterState *interp = PyInterpreterState_Get(); + assert(_PyInterpreterState_Refcount(interp) == 0); + PyInterpreterRef refs[NUM_REFS]; + for (int i = 0; i < NUM_REFS; ++i) { + int res = PyInterpreterRef_Get(&refs[i]); + (void)res; + assert(res == 0); + assert(_PyInterpreterState_Refcount(interp) == i + 1); + } + + for (int i = 0; i < NUM_REFS; ++i) { + PyInterpreterRef_Close(refs[i]); + assert(_PyInterpreterState_Refcount(interp) == (NUM_REFS - i - 1)); + } + + Py_RETURN_NONE; +} + +static PyObject * +test_interp_weakref_incref(PyObject *self, PyObject *unused) +{ + PyInterpreterState *interp = PyInterpreterState_Get(); + PyInterpreterWeakRef wref; + if (PyInterpreterWeakRef_Get(&wref) < 0) { + return NULL; + } + assert(_PyInterpreterState_Refcount(interp) == 0); + + PyInterpreterRef refs[NUM_REFS]; + + for (int i = 0; i < NUM_REFS; ++i) { + int res = PyInterpreterWeakRef_AsStrong(wref, &refs[i]); + (void)res; + assert(res == 0); + assert(PyInterpreterRef_AsInterpreter(refs[i]) == interp); + assert(_PyInterpreterState_Refcount(interp) == i + 1); + } + + for (int i = 0; i < NUM_REFS; ++i) { + PyInterpreterRef_Close(refs[i]); + assert(_PyInterpreterState_Refcount(interp) == (NUM_REFS - i - 1)); + } + + PyInterpreterWeakRef_Close(wref); + Py_RETURN_NONE; +} + +#undef NUM_REFS + static PyMethodDef module_functions[] = { {"get_configs", get_configs, METH_NOARGS}, {"get_recursion_depth", get_recursion_depth, METH_NOARGS}, @@ -2447,6 +2501,8 @@ static PyMethodDef module_functions[] = { {"is_static_immortal", is_static_immortal, METH_O}, {"incref_decref_delayed", incref_decref_delayed, METH_O}, GET_NEXT_DICT_KEYS_VERSION_METHODDEF + {"test_interp_refcount", test_interp_refcount, METH_NOARGS}, + {"test_interp_weakref_incref", test_interp_weakref_incref, METH_NOARGS}, {NULL, NULL} /* sentinel */ }; diff --git a/Programs/_testembed.c b/Programs/_testembed.c index 577da65c7cdafa..3d98e8f0b79d28 100644 --- a/Programs/_testembed.c +++ b/Programs/_testembed.c @@ -2313,6 +2313,66 @@ test_get_incomplete_frame(void) return result; } +const char *THREAD_CODE = \ + "import time\n" + "time.sleep(0.2)\n" + "def fib(n):\n" + " if n <= 1:\n" + " return n\n" + " else:\n" + " return fib(n - 1) + fib(n - 2)\n" + "fib(10)"; + +typedef struct { + PyInterpreterRef ref; + int done; +} ThreadData; + +static void +do_tstate_ensure(void *arg) +{ + ThreadData *data = (ThreadData *)arg; + int res = PyThreadState_Ensure(data->ref); + assert(res == 0); + PyThreadState_Ensure(data->ref); + PyThreadState_Ensure(data->ref); + PyGILState_STATE gstate = PyGILState_Ensure(); + PyThreadState_Ensure(data->ref); + res = PyRun_SimpleString(THREAD_CODE); + PyThreadState_Release(); + PyGILState_Release(gstate); + PyThreadState_Release(); + PyThreadState_Release(); + assert(res == 0); + PyThreadState_Release(); + PyInterpreterRef_Close(data->ref); + data->done = 1; +} + +static int +test_thread_state_ensure(void) +{ + _testembed_initialize(); + PyThread_handle_t handle; + PyThread_ident_t ident; + PyInterpreterRef ref; + if (PyInterpreterRef_Get(&ref) < 0) { + return -1; + }; + ThreadData data = { ref }; + if (PyThread_start_joinable_thread(do_tstate_ensure, &data, + &ident, &handle) < 0) { + PyInterpreterRef_Close(ref); + return -1; + } + // We hold a strong interpreter reference, so we don't + // have to worry about the interpreter shutting down before + // we finalize. + Py_Finalize(); + assert(data.done == 1); + return 0; +} + static void do_gilstate_ensure(void *event_ptr) { @@ -2340,6 +2400,31 @@ test_gilstate_after_finalization(void) return PyThread_detach_thread(handle); } +static int +test_main_interpreter_ref(void) +{ + // It should not work before the runtime has started. + PyInterpreterRef ref; + int res = PyInterpreterRef_Main(&ref); + (void)res; + assert(res == -1); + + _testembed_initialize(); + + // Main interpreter is initialized and ready. + res = PyInterpreterRef_Main(&ref); + assert(res == 0); + assert(PyInterpreterRef_AsInterpreter(ref) == PyInterpreterState_Main()); + PyInterpreterRef_Close(ref); + + Py_Finalize(); + + // Main interpreter is dead, we can no longer acquire references to it. + res = PyInterpreterRef_Main(&ref); + assert(res == -1); + return 0; +} + /* ********************************************************* * List of test cases and the function that implements it. * @@ -2429,7 +2514,9 @@ static struct TestCase TestCases[] = { {"test_frozenmain", test_frozenmain}, #endif {"test_get_incomplete_frame", test_get_incomplete_frame}, + {"test_thread_state_ensure", test_thread_state_ensure}, {"test_gilstate_after_finalization", test_gilstate_after_finalization}, + {"test_main_interpreter_ref", test_main_interpreter_ref}, {NULL, NULL} }; diff --git a/Python/pylifecycle.c b/Python/pylifecycle.c index 724fda63511282..a456ee3be2cb99 100644 --- a/Python/pylifecycle.c +++ b/Python/pylifecycle.c @@ -97,6 +97,7 @@ static PyStatus init_android_streams(PyThreadState *tstate); static PyStatus init_apple_streams(PyThreadState *tstate); #endif static void wait_for_thread_shutdown(PyThreadState *tstate); +static void wait_for_interp_references(PyInterpreterState *interp); static void finalize_subinterpreters(void); static void call_ll_exitfuncs(_PyRuntimeState *runtime); @@ -2022,6 +2023,9 @@ _Py_Finalize(_PyRuntimeState *runtime) // Wrap up existing "threading"-module-created, non-daemon threads. wait_for_thread_shutdown(tstate); + // Wait for the interpreter's reference count to reach zero + wait_for_interp_references(tstate->interp); + // Make any remaining pending calls. _Py_FinishPendingCalls(tstate); @@ -2438,6 +2442,9 @@ Py_EndInterpreter(PyThreadState *tstate) // Wrap up existing "threading"-module-created, non-daemon threads. wait_for_thread_shutdown(tstate); + // Wait for the interpreter's reference count to reach zero + wait_for_interp_references(tstate->interp); + // Make any remaining pending calls. _Py_FinishPendingCalls(tstate); @@ -3464,6 +3471,47 @@ wait_for_thread_shutdown(PyThreadState *tstate) Py_DECREF(threading); } +/* Wait for the interpreter's reference count to reach zero. + See PEP 788. */ +static void +wait_for_interp_references(PyInterpreterState *interp) +{ + assert(interp != NULL); + struct _Py_finalizing_threads *finalizing = &interp->threads.finalizing; + _Py_atomic_store_int_release(&finalizing->shutting_down, 1); + PyMutex_Lock(&finalizing->mutex); + if (_Py_atomic_load_ssize_relaxed(&finalizing->countdown) == 0) { + // Nothing to do. + PyMutex_Unlock(&finalizing->mutex); + return; + } + PyMutex_Unlock(&finalizing->mutex); + + PyTime_t wait_max = 1000 * 1000 * 100; // 100 milliseconds + PyTime_t wait_ns = 1000; // 1 microsecond + + while (true) { + if (PyEvent_WaitTimed(&finalizing->finished, wait_ns, 1)) { + // Event set + break; + } + + wait_ns *= 2; + wait_ns = Py_MIN(wait_ns, wait_max); + + if (PyErr_CheckSignals()) { + PyErr_FormatUnraisable("Exception ignored while waiting on interpreter shutdown"); + /* The user CTRL+C'd us, bail out without waiting for a reference + count of zero. + + This will probably cause threads to crash, but maybe that's + better than a deadlock. It might be worth intentionally + leaking subinterpreters to prevent some crashes here. */ + break; + } + } +} + int Py_AtExit(void (*func)(void)) { struct _atexit_runtime_state *state = &_PyRuntime.atexit; diff --git a/Python/pystate.c b/Python/pystate.c index 0d4c26f92cec90..bbc0710a5a45a3 100644 --- a/Python/pystate.c +++ b/Python/pystate.c @@ -1300,12 +1300,8 @@ interp_look_up_id(_PyRuntimeState *runtime, int64_t requested_id) return NULL; } -/* Return the interpreter state with the given ID. - - Fail with RuntimeError if the interpreter is not found. */ - -PyInterpreterState * -_PyInterpreterState_LookUpID(int64_t requested_id) +static PyInterpreterState * +_PyInterpreterState_LookUpIDNoErr(int64_t requested_id) { PyInterpreterState *interp = NULL; if (requested_id >= 0) { @@ -1314,6 +1310,18 @@ _PyInterpreterState_LookUpID(int64_t requested_id) interp = interp_look_up_id(runtime, requested_id); HEAD_UNLOCK(runtime); } + return interp; +} + +/* Return the interpreter state with the given ID. + + Fail with RuntimeError if the interpreter is not found. */ + +PyInterpreterState * +_PyInterpreterState_LookUpID(int64_t requested_id) +{ + assert(_PyThreadState_GET() != NULL); + PyInterpreterState *interp = _PyInterpreterState_LookUpIDNoErr(requested_id); if (interp == NULL && !PyErr_Occurred()) { PyErr_Format(PyExc_InterpreterNotFoundError, "unrecognized interpreter ID %lld", requested_id); @@ -1525,7 +1533,6 @@ new_threadstate(PyInterpreterState *interp, int whence) return NULL; } #endif - /* We serialize concurrent creation to protect global state. */ HEAD_LOCK(interp->runtime); @@ -1729,6 +1736,27 @@ PyThreadState_Clear(PyThreadState *tstate) static void decrement_stoptheworld_countdown(struct _stoptheworld_state *stw); +static int +shutting_down_natives(PyInterpreterState *interp) +{ + assert(interp != NULL); + return _Py_atomic_load_int_relaxed(&interp->threads.finalizing.shutting_down); +} + +static void +decref_interpreter(PyInterpreterState *interp) +{ + assert(interp != NULL); + struct _Py_finalizing_threads *finalizing = &interp->threads.finalizing; + Py_ssize_t old = _Py_atomic_add_ssize(&finalizing->countdown, -1); + if (old == 1 && shutting_down_natives(interp)) { + _PyEvent_Notify(&finalizing->finished); + } else if (old <= 0) { + Py_FatalError("interpreter has negative reference count, likely due" + " to an extra PyInterpreterRef_Close()"); + } +} + /* Common code for PyThreadState_Delete() and PyThreadState_DeleteCurrent() */ static void tstate_delete_common(PyThreadState *tstate, int release_gil) @@ -2734,30 +2762,24 @@ PyGILState_Check(void) PyGILState_STATE PyGILState_Ensure(void) { - _PyRuntimeState *runtime = &_PyRuntime; - /* Note that we do not auto-init Python here - apart from potential races with 2 threads auto-initializing, pep-311 spells out other issues. Embedders are expected to have called Py_Initialize(). */ - /* Ensure that _PyEval_InitThreads() and _PyGILState_Init() have been - called by Py_Initialize() - - TODO: This isn't thread-safe. There's no protection here against - concurrent finalization of the interpreter; it's simply a guard - for *after* the interpreter has finalized. - */ - if (!_PyEval_ThreadsInitialized() || runtime->gilstate.autoInterpreterState == NULL) { - PyThread_hang_thread(); - } - PyThreadState *tcur = gilstate_get(); int has_gil; + PyInterpreterRef ref; if (tcur == NULL) { /* Create a new Python thread state for this thread */ // XXX Use PyInterpreterState_EnsureThreadState()? - tcur = new_threadstate(runtime->gilstate.autoInterpreterState, + if (PyInterpreterRef_Main(&ref) < 0) { + // The main interpreter has finished, so we don't have + // any intepreter to make a thread state for. Hang the + // thread to act as failure. + PyThread_hang_thread(); + } + tcur = new_threadstate(PyInterpreterRef_AsInterpreter(ref), _PyThreadState_WHENCE_GILSTATE); if (tcur == NULL) { Py_FatalError("Couldn't create thread-state for new thread"); @@ -2770,12 +2792,14 @@ PyGILState_Ensure(void) assert(tcur->gilstate_counter == 1); tcur->gilstate_counter = 0; has_gil = 0; /* new thread state is never current */ + PyInterpreterRef_Close(ref); } else { has_gil = holds_gil(tcur); } if (!has_gil) { + // XXX Do we need to protect this against finalization? PyEval_RestoreThread(tcur); } @@ -3112,3 +3136,244 @@ _Py_GetMainConfig(void) } return _PyInterpreterState_GetConfig(interp); } + +Py_ssize_t +_PyInterpreterState_Refcount(PyInterpreterState *interp) +{ + assert(interp != NULL); + return _Py_atomic_load_ssize_relaxed(&interp->threads.finalizing.countdown); +} + +int +_PyInterpreterState_Incref(PyInterpreterState *interp) +{ + assert(interp != NULL); + struct _Py_finalizing_threads *finalizing = &interp->threads.finalizing; + assert(_Py_atomic_load_ssize_relaxed(&finalizing->countdown) >= 0); + PyMutex *mutex = &finalizing->mutex; + PyMutex_Lock(mutex); + if (_PyEvent_IsSet(&finalizing->finished)) { + PyMutex_Unlock(mutex); + return -1; + } + + _Py_atomic_add_ssize(&interp->threads.finalizing.countdown, 1); + PyMutex_Unlock(mutex); + return 0; +} + +static PyInterpreterState * +ref_as_interp(PyInterpreterRef ref) +{ + PyInterpreterState *interp = (PyInterpreterState *)ref; + if (interp == NULL) { + Py_FatalError("Got a null interpreter reference, likely due to use after PyInterpreterRef_Close()"); + } + + return interp; +} + +int +PyInterpreterRef_Get(PyInterpreterRef *ref) +{ + assert(ref != NULL); + PyInterpreterState *interp = PyInterpreterState_Get(); + if (_PyInterpreterState_Incref(interp) < 0) { + PyErr_SetString(PyExc_PythonFinalizationError, + "Cannot acquire strong interpreter references anymore"); + return -1; + } + *ref = (PyInterpreterRef)interp; + return 0; +} + +PyInterpreterRef +PyInterpreterRef_Dup(PyInterpreterRef ref) +{ + PyInterpreterState *interp = ref_as_interp(ref); + int res = _PyInterpreterState_Incref(interp); + (void)res; + // We already hold a strong reference, so it shouldn't be possible + // for the interpreter to be at a point where references don't work anymore + assert(res == 0); + return (PyInterpreterRef)interp; +} + +#undef PyInterpreterRef_Close +void +PyInterpreterRef_Close(PyInterpreterRef ref) +{ + PyInterpreterState *interp = ref_as_interp(ref); + decref_interpreter(interp); +} + +PyInterpreterState * +PyInterpreterRef_AsInterpreter(PyInterpreterRef ref) +{ + PyInterpreterState *interp = ref_as_interp(ref); + return interp; +} + +int +PyInterpreterWeakRef_Get(PyInterpreterWeakRef *wref_ptr) +{ + PyInterpreterState *interp = PyInterpreterState_Get(); + /* PyInterpreterWeakRef_Close() can be called without an attached thread + state, so we have to use the raw allocator. */ + _PyInterpreterWeakRef *wref = PyMem_RawMalloc(sizeof(_PyInterpreterWeakRef)); + if (wref == NULL) { + PyErr_NoMemory(); + return -1; + } + wref->refcount = 1; + wref->id = interp->id; + *wref_ptr = (PyInterpreterWeakRef)wref; + return 0; +} + +static _PyInterpreterWeakRef * +wref_handle_as_ptr(PyInterpreterWeakRef wref_handle) +{ + _PyInterpreterWeakRef *wref = (_PyInterpreterWeakRef *)wref_handle; + if (wref == NULL) { + Py_FatalError("Got a null weak interpreter reference, likely due to use after PyInterpreterWeakRef_Close()"); + } + + return wref; +} + +PyInterpreterWeakRef +PyInterpreterWeakRef_Dup(PyInterpreterWeakRef wref_handle) +{ + _PyInterpreterWeakRef *wref = wref_handle_as_ptr(wref_handle); + ++wref->refcount; + return wref; +} + +#undef PyInterpreterWeakRef_Close +void +PyInterpreterWeakRef_Close(PyInterpreterWeakRef wref_handle) +{ + _PyInterpreterWeakRef *wref = wref_handle_as_ptr(wref_handle); + if (--wref->refcount == 0) { + PyMem_RawFree(wref); + } +} + +static int +try_acquire_strong_ref(PyInterpreterState *interp, PyInterpreterRef *strong_ptr) +{ + struct _Py_finalizing_threads *finalizing = &interp->threads.finalizing; + PyMutex *mutex = &finalizing->mutex; + PyMutex_Lock(mutex); // Synchronize TOCTOU with the event flag + if (_PyEvent_IsSet(&finalizing->finished)) { + /* Interpreter has already finished threads */ + *strong_ptr = 0; + return -1; + } + else { + _Py_atomic_add_ssize(&finalizing->countdown, 1); + } + PyMutex_Unlock(mutex); + *strong_ptr = (PyInterpreterRef)interp; + return 0; +} + +int +PyInterpreterWeakRef_AsStrong(PyInterpreterWeakRef wref_handle, PyInterpreterRef *strong_ptr) +{ + assert(strong_ptr != NULL); + _PyInterpreterWeakRef *wref = wref_handle_as_ptr(wref_handle); + int64_t interp_id = wref->id; + /* Interpreters cannot be deleted while we hold the runtime lock. */ + _PyRuntimeState *runtime = &_PyRuntime; + HEAD_LOCK(runtime); + PyInterpreterState *interp = interp_look_up_id(runtime, interp_id); + if (interp == NULL) { + HEAD_UNLOCK(runtime); + *strong_ptr = 0; + return -1; + } + + int res = try_acquire_strong_ref(interp, strong_ptr); + HEAD_UNLOCK(runtime); + return res; +} + +int +PyInterpreterRef_Main(PyInterpreterRef *strong_ptr) +{ + assert(strong_ptr != NULL); + _PyRuntimeState *runtime = &_PyRuntime; + HEAD_LOCK(runtime); + if (runtime->initialized == 0) { + // Main interpreter is not initialized. + // This can be the case before Py_Initialize(), or after Py_Finalize(). + HEAD_UNLOCK(runtime); + return -1; + } + int res = try_acquire_strong_ref(&runtime->_main_interpreter, strong_ptr); + HEAD_UNLOCK(runtime); + + return res; +} + +int +PyThreadState_Ensure(PyInterpreterRef interp_ref) +{ + PyInterpreterState *interp = ref_as_interp(interp_ref); + PyThreadState *attached_tstate = current_fast_get(); + if (attached_tstate != NULL && attached_tstate->interp == interp) { + /* Yay! We already have an attached thread state that matches. */ + ++attached_tstate->ensure.counter; + return 0; + } + + PyThreadState *detached_gilstate = gilstate_get(); + if (detached_gilstate != NULL && detached_gilstate->interp == interp) { + /* There's a detached thread state that works. */ + assert(attached_tstate == NULL); + ++detached_gilstate->ensure.counter; + _PyThreadState_Attach(detached_gilstate); + return 0; + } + + PyThreadState *fresh_tstate = _PyThreadState_NewBound(interp, + _PyThreadState_WHENCE_GILSTATE); + if (fresh_tstate == NULL) { + return -1; + } + fresh_tstate->ensure.counter = 1; + fresh_tstate->ensure.delete_on_release = 1; + + if (attached_tstate != NULL) { + fresh_tstate->ensure.prior_tstate = PyThreadState_Swap(fresh_tstate); + } else { + _PyThreadState_Attach(fresh_tstate); + } + + return 0; +} + +void +PyThreadState_Release(void) +{ + PyThreadState *tstate = current_fast_get(); + _Py_EnsureTstateNotNULL(tstate); + Py_ssize_t remaining = --tstate->ensure.counter; + if (remaining < 0) { + Py_FatalError("PyThreadState_Release() called more times than PyThreadState_Ensure()"); + } + PyThreadState *to_restore = tstate->ensure.prior_tstate; + if (remaining == 0) { + if (tstate->ensure.delete_on_release) { + PyThreadState_Clear(tstate); + PyThreadState_Swap(to_restore); + PyThreadState_Delete(tstate); + } else { + PyThreadState_Swap(to_restore); + } + } + + return; +}