From 770f166edf92b99aa06be88cb1acaab4c78786a7 Mon Sep 17 00:00:00 2001 From: Jeong Yunwon Date: Thu, 16 Jun 2022 20:53:59 +0900 Subject: [PATCH] Fix slot call deadlock --- vm/src/builtins/type.rs | 15 +++++++++------ vm/src/function/protocol.rs | 8 ++++---- vm/src/object/core.rs | 3 ++- vm/src/protocol/buffer.rs | 3 ++- vm/src/protocol/mapping.rs | 7 ++----- vm/src/protocol/number.rs | 5 ++--- vm/src/protocol/sequence.rs | 3 ++- 7 files changed, 23 insertions(+), 21 deletions(-) diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index e0cc7e3370..d4e09f8c27 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -667,11 +667,12 @@ impl GetAttr for PyType { if let Some(ref attr) = mcl_attr { let attr_class = attr.class(); - if attr_class + let has_descr_set = attr_class .mro_find_map(|cls| cls.slots.descr_set.load()) - .is_some() - { - if let Some(descr_get) = attr_class.mro_find_map(|cls| cls.slots.descr_get.load()) { + .is_some(); + if has_descr_set { + let descr_get = attr_class.mro_find_map(|cls| cls.slots.descr_get.load()); + if let Some(descr_get) = descr_get { let mcl = mcl.into_owned().into(); return descr_get(attr.clone(), Some(zelf.to_owned().into()), Some(mcl), vm); } @@ -681,7 +682,8 @@ impl GetAttr for PyType { let zelf_attr = zelf.get_attr(name); if let Some(ref attr) = zelf_attr { - if let Some(descr_get) = attr.class().mro_find_map(|cls| cls.slots.descr_get.load()) { + let descr_get = attr.class().mro_find_map(|cls| cls.slots.descr_get.load()); + if let Some(descr_get) = descr_get { drop(mcl); return descr_get(attr.clone(), None, Some(zelf.to_owned().into()), vm); } @@ -745,7 +747,8 @@ impl Callable for PyType { return Ok(obj); } - if let Some(init_method) = obj.class().mro_find_map(|cls| cls.slots.init.load()) { + let init = obj.class().mro_find_map(|cls| cls.slots.init.load()); + if let Some(init_method) = init { init_method(obj.clone(), args, vm)?; } Ok(obj) diff --git a/vm/src/function/protocol.rs b/vm/src/function/protocol.rs index b6035ff968..5bf643201f 100644 --- a/vm/src/function/protocol.rs +++ b/vm/src/function/protocol.rs @@ -84,14 +84,14 @@ where T: TryFromObject, { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let iterfn; - { + let iterfn = { let cls = obj.class(); - iterfn = cls.mro_find_map(|x| x.slots.iter.load()); + let iterfn = cls.mro_find_map(|x| x.slots.iter.load()); if iterfn.is_none() && !cls.has_attr(identifier!(vm, __getitem__)) { return Err(vm.new_type_error(format!("'{}' object is not iterable", cls.name()))); } - } + iterfn + }; Ok(Self { iterable: obj, iterfn, diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index 87ee67fc81..7b6ee4e06a 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -762,7 +762,8 @@ impl PyObject { } // CPython-compatible drop implementation - if let Some(slot_del) = self.class().mro_find_map(|cls| cls.slots.del.load()) { + let del = self.class().mro_find_map(|cls| cls.slots.del.load()); + if let Some(slot_del) = del { call_slot_del(self, slot_del)?; } if let Some(wrl) = self.weak_ref_list() { diff --git a/vm/src/protocol/buffer.rs b/vm/src/protocol/buffer.rs index 9f4c191be0..5e90630139 100644 --- a/vm/src/protocol/buffer.rs +++ b/vm/src/protocol/buffer.rs @@ -140,7 +140,8 @@ impl PyBuffer { impl TryFromBorrowedObject for PyBuffer { fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult { let cls = obj.class(); - if let Some(f) = cls.mro_find_map(|cls| cls.slots.as_buffer) { + let as_buffer = cls.mro_find_map(|cls| cls.slots.as_buffer); + if let Some(f) = as_buffer { return f(obj, vm); } Err(vm.new_type_error(format!( diff --git a/vm/src/protocol/mapping.rs b/vm/src/protocol/mapping.rs index 98a45038fe..d21184a2a9 100644 --- a/vm/src/protocol/mapping.rs +++ b/vm/src/protocol/mapping.rs @@ -146,11 +146,8 @@ impl PyMapping<'_> { } pub fn find_methods(obj: &PyObject, vm: &VirtualMachine) -> Option<&'static PyMappingMethods> { - if let Some(f) = obj.class().mro_find_map(|cls| cls.slots.as_mapping.load()) { - Some(f(obj, vm)) - } else { - None - } + let as_mapping = obj.class().mro_find_map(|cls| cls.slots.as_mapping.load()); + as_mapping.map(|f| f(obj, vm)) } pub fn length_opt(&self, vm: &VirtualMachine) -> Option> { diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 2c2100168f..e8f1e90412 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -163,9 +163,8 @@ impl<'a> PyNumber<'a> { impl PyNumber<'_> { pub fn find_methods(obj: &PyObject, vm: &VirtualMachine) -> Option<&'static PyNumberMethods> { - obj.class() - .mro_find_map(|x| x.slots.as_number.load()) - .map(|f| f(obj, vm)) + let as_number = obj.class().mro_find_map(|x| x.slots.as_number.load()); + as_number.map(|f| f(obj, vm)) } pub fn methods(&self) -> &'static PyNumberMethods { diff --git a/vm/src/protocol/sequence.rs b/vm/src/protocol/sequence.rs index c279dcb8df..eda0b8c7f1 100644 --- a/vm/src/protocol/sequence.rs +++ b/vm/src/protocol/sequence.rs @@ -145,7 +145,8 @@ impl PySequence<'_> { pub fn find_methods(obj: &PyObject, vm: &VirtualMachine) -> Option<&'static PySequenceMethods> { let cls = obj.class(); if !cls.is(vm.ctx.types.dict_type) { - if let Some(f) = cls.mro_find_map(|x| x.slots.as_sequence.load()) { + let as_sequence = cls.mro_find_map(|x| x.slots.as_sequence.load()); + if let Some(f) = as_sequence { return Some(f(obj, vm)); } }