@@ -195,23 +195,23 @@ context_switched(PyThreadState *ts)
195195}
196196
197197
198+ // ts is not required to belong to the calling thread.
198199static int
199200_PyContext_Enter (PyThreadState * ts , PyObject * octx )
200201{
201202 ENSURE_Context (octx , -1 )
202203 PyContext * ctx = (PyContext * )octx ;
203204
204205 if (ctx -> ctx_entered ) {
205- _PyErr_Format ( ts , PyExc_RuntimeError ,
206- "cannot enter context: %R is already entered" , ctx );
206+ PyErr_Format ( PyExc_RuntimeError ,
207+ "cannot enter context: %R is already entered" , ctx );
207208 return -1 ;
208209 }
209210
210211 ctx -> ctx_prev = (PyContext * )ts -> context ; /* borrow */
211212 ctx -> ctx_entered = 1 ;
212213
213214 ts -> context = Py_NewRef (ctx );
214- context_switched (ts );
215215 return 0 ;
216216}
217217
@@ -221,10 +221,15 @@ PyContext_Enter(PyObject *octx)
221221{
222222 PyThreadState * ts = _PyThreadState_GET ();
223223 assert (ts != NULL );
224- return _PyContext_Enter (ts , octx );
224+ if (_PyContext_Enter (ts , octx )) {
225+ return -1 ;
226+ }
227+ context_switched (ts );
228+ return 0 ;
225229}
226230
227231
232+ // ts is not required to belong to the calling thread.
228233static int
229234_PyContext_Exit (PyThreadState * ts , PyObject * octx )
230235{
@@ -250,7 +255,6 @@ _PyContext_Exit(PyThreadState *ts, PyObject *octx)
250255 ctx -> ctx_prev = NULL ;
251256 ctx -> ctx_entered = 0 ;
252257 ctx -> ctx_owned_by_thread = 0 ;
253- context_switched (ts );
254258 return 0 ;
255259}
256260
@@ -259,25 +263,36 @@ PyContext_Exit(PyObject *octx)
259263{
260264 PyThreadState * ts = _PyThreadState_GET ();
261265 assert (ts != NULL );
262- return _PyContext_Exit (ts , octx );
266+ if (_PyContext_Exit (ts , octx )) {
267+ return -1 ;
268+ }
269+ context_switched (ts );
270+ return 0 ;
263271}
264272
265273
266274void
267275_PyContext_ExitThreadOwned (PyThreadState * ts )
268276{
269277 assert (ts != NULL );
270- // notify_context_watchers requires the notification to come from the
271- // affected thread, so we can only exit the context(s) if ts belongs to the
272- // current thread.
273- _Bool on_thread = ts == _PyThreadState_GET ();
274278 while (ts -> context != NULL
275279 && PyContext_CheckExact (ts -> context )
276- && ((PyContext * )ts -> context )-> ctx_owned_by_thread
277- && on_thread ) {
280+ && ((PyContext * )ts -> context )-> ctx_owned_by_thread ) {
278281 if (_PyContext_Exit (ts , ts -> context )) {
282+ // Exiting a context that is already known to be at the top of the
283+ // stack cannot fail.
279284 Py_UNREACHABLE ();
280285 }
286+ // notify_context_watchers() requires the notification to come from the
287+ // affected thread, so context_switched() must not be called if ts
288+ // doesn't belong to the current thread. However, it's OK to skip
289+ // calling it in this case: this function is only called when resetting
290+ // a PyThreadState, so if the calling thread doesn't own ts, then the
291+ // owning thread must not be running anymore (it must have just
292+ // finished because a thread-owned context exists here).
293+ if (ts == _PyThreadState_GET ()) {
294+ context_switched (ts );
295+ }
281296 }
282297 if (ts -> context != NULL ) {
283298 // This intentionally does not use tstate variants of these functions
@@ -518,18 +533,15 @@ context_get(void)
518533 assert (ts != NULL );
519534 if (ts -> context == NULL ) {
520535 PyContext * ctx = context_new_empty ();
521- if (ctx != NULL ) {
522- if (_PyContext_Enter (ts , (PyObject * )ctx )) {
523- Py_UNREACHABLE ();
524- }
525- ctx -> ctx_owned_by_thread = 1 ;
536+ if (ctx == NULL || _PyContext_Enter (ts , (PyObject * )ctx )) {
537+ return NULL ;
526538 }
539+ ctx -> ctx_owned_by_thread = 1 ;
527540 assert (ts -> context == (PyObject * )ctx );
528541 Py_CLEAR (ctx ); // _PyContext_Enter created its own ref.
542+ context_switched (ts );
529543 }
530- // The current context may be NULL if the above context_new_empty() call
531- // failed.
532- assert (ts -> context == NULL || PyContext_CheckExact (ts -> context ));
544+ assert (PyContext_CheckExact (ts -> context ));
533545 return (PyContext * )ts -> context ;
534546}
535547
@@ -759,6 +771,7 @@ context_run(PyContext *self, PyObject *const *args,
759771 if (_PyContext_Enter (ts , (PyObject * )self )) {
760772 return NULL ;
761773 }
774+ context_switched (ts );
762775
763776 PyObject * call_result = _PyObject_VectorcallTstate (
764777 ts , args [0 ], args + 1 , nargs - 1 , kwnames );
@@ -767,6 +780,7 @@ context_run(PyContext *self, PyObject *const *args,
767780 Py_XDECREF (call_result );
768781 return NULL ;
769782 }
783+ context_switched (ts );
770784
771785 return call_result ;
772786}
0 commit comments