Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 960f87b

Browse files
committed
functools.partial
1 parent 9952c97 commit 960f87b

File tree

2 files changed

+292
-3
lines changed

2 files changed

+292
-3
lines changed

Lib/test/test_functools.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,6 @@ class TestPartialC(TestPartial, unittest.TestCase):
396396
module = c_functools
397397
partial = c_functools.partial
398398

399-
# TODO: RUSTPYTHON
400-
@unittest.expectedFailure
401399
def test_attributes_unwritable(self):
402400
# attributes should not be writable
403401
p = self.partial(capture, 1, 2, a=10, b=20)

vm/src/stdlib/functools.rs

Lines changed: 292 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@ pub(crate) use _functools::make_module;
22

33
#[pymodule]
44
mod _functools {
5-
use crate::{PyObjectRef, PyResult, VirtualMachine, function::OptionalArg, protocol::PyIter};
5+
use crate::{
6+
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
7+
builtins::{PyDict, PyTuple, PyTypeRef},
8+
common::lock::PyRwLock,
9+
function::{FuncArgs, KwArgs, OptionalArg},
10+
object::AsObject,
11+
protocol::PyIter,
12+
pyclass,
13+
recursion::ReprGuard,
14+
types::{Callable, Constructor, Representable},
15+
};
16+
use indexmap::IndexMap;
617

718
#[pyfunction]
819
fn reduce(
@@ -30,4 +41,284 @@ mod _functools {
3041
}
3142
Ok(accumulator)
3243
}
44+
45+
#[pyattr]
46+
#[pyclass(name = "partial", module = "_functools")]
47+
#[derive(Debug, PyPayload)]
48+
pub struct PyPartial {
49+
inner: PyRwLock<PyPartialInner>,
50+
}
51+
52+
#[derive(Debug)]
53+
struct PyPartialInner {
54+
func: PyObjectRef,
55+
args: PyRef<PyTuple>,
56+
keywords: PyRef<PyDict>,
57+
}
58+
59+
#[pyclass(with(Constructor, Callable, Representable), flags(BASETYPE, HAS_DICT))]
60+
impl PyPartial {
61+
#[pygetset]
62+
fn func(&self) -> PyObjectRef {
63+
self.inner.read().func.clone()
64+
}
65+
66+
#[pygetset]
67+
fn args(&self) -> PyRef<PyTuple> {
68+
self.inner.read().args.clone()
69+
}
70+
71+
#[pygetset]
72+
fn keywords(&self) -> PyRef<PyDict> {
73+
self.inner.read().keywords.clone()
74+
}
75+
76+
#[pymethod(name = "__reduce__")]
77+
fn reduce(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult {
78+
let inner = zelf.inner.read();
79+
let partial_type = zelf.class();
80+
81+
// Get __dict__ if it exists and is not empty
82+
let dict_obj = match zelf.as_object().dict() {
83+
Some(dict) if !dict.is_empty() => dict.into(),
84+
_ => vm.ctx.none(),
85+
};
86+
87+
let state = vm.ctx.new_tuple(vec![
88+
inner.func.clone(),
89+
inner.args.clone().into(),
90+
inner.keywords.clone().into(),
91+
dict_obj,
92+
]);
93+
Ok(vm
94+
.ctx
95+
.new_tuple(vec![
96+
partial_type.to_owned().into(),
97+
vm.ctx.new_tuple(vec![inner.func.clone()]).into(),
98+
state.into(),
99+
])
100+
.into())
101+
}
102+
103+
#[pymethod(magic)]
104+
fn setstate(zelf: &Py<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
105+
let state_tuple = state.downcast::<PyTuple>().map_err(|_| {
106+
vm.new_type_error("argument to __setstate__ must be a tuple".to_owned())
107+
})?;
108+
109+
if state_tuple.len() != 4 {
110+
return Err(vm.new_type_error(format!(
111+
"expected 4 items in state, got {}",
112+
state_tuple.len()
113+
)));
114+
}
115+
116+
let func = &state_tuple[0];
117+
let args = &state_tuple[1];
118+
let kwds = &state_tuple[2];
119+
let dict = &state_tuple[3];
120+
121+
if !func.is_callable() {
122+
return Err(vm.new_type_error("invalid partial state".to_owned()));
123+
}
124+
125+
// Validate that args is a tuple (or subclass)
126+
if !args.fast_isinstance(vm.ctx.types.tuple_type) {
127+
return Err(vm.new_type_error("invalid partial state".to_owned()));
128+
}
129+
// Always convert to base tuple, even if it's a subclass
130+
let args_tuple = match args.clone().downcast::<PyTuple>() {
131+
Ok(tuple) if tuple.class().is(vm.ctx.types.tuple_type) => tuple,
132+
_ => {
133+
// It's a tuple subclass, convert to base tuple
134+
let elements: Vec<PyObjectRef> = args.try_to_value(vm)?;
135+
vm.ctx.new_tuple(elements)
136+
}
137+
};
138+
139+
let keywords_dict = if kwds.is(&vm.ctx.none) {
140+
vm.ctx.new_dict()
141+
} else {
142+
// Always convert to base dict, even if it's a subclass
143+
let dict = kwds
144+
.clone()
145+
.downcast::<PyDict>()
146+
.map_err(|_| vm.new_type_error("invalid partial state".to_owned()))?;
147+
if dict.class().is(vm.ctx.types.dict_type) {
148+
// It's already a base dict
149+
dict
150+
} else {
151+
// It's a dict subclass, convert to base dict
152+
let new_dict = vm.ctx.new_dict();
153+
for (key, value) in dict {
154+
new_dict.set_item(&*key, value, vm)?;
155+
}
156+
new_dict
157+
}
158+
};
159+
160+
// Actually update the state
161+
let mut inner = zelf.inner.write();
162+
inner.func = func.clone();
163+
// Handle args - use the already validated tuple
164+
inner.args = args_tuple;
165+
166+
// Handle keywords - keep the original type
167+
inner.keywords = keywords_dict;
168+
169+
// Update __dict__ if provided
170+
let Some(instance_dict) = zelf.as_object().dict() else {
171+
return Ok(());
172+
};
173+
174+
if dict.is(&vm.ctx.none) {
175+
// If dict is None, clear the instance dict
176+
instance_dict.clear();
177+
return Ok(());
178+
}
179+
180+
let dict_obj = dict
181+
.clone()
182+
.downcast::<PyDict>()
183+
.map_err(|_| vm.new_type_error("invalid partial state".to_owned()))?;
184+
185+
// Clear existing dict and update with new values
186+
instance_dict.clear();
187+
for (key, value) in dict_obj {
188+
instance_dict.set_item(&*key, value, vm)?;
189+
}
190+
191+
Ok(())
192+
}
193+
}
194+
195+
impl Constructor for PyPartial {
196+
type Args = FuncArgs;
197+
198+
fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
199+
let (func, args_slice) = args.args.split_first().ok_or_else(|| {
200+
vm.new_type_error("partial expected at least 1 argument, got 0".to_owned())
201+
})?;
202+
203+
if !func.is_callable() {
204+
return Err(vm.new_type_error("the first argument must be callable".to_owned()));
205+
}
206+
207+
// Handle nested partial objects
208+
let (final_func, final_args, final_keywords) =
209+
if let Some(partial) = func.downcast_ref::<PyPartial>() {
210+
let inner = partial.inner.read();
211+
let mut combined_args = inner.args.as_slice().to_vec();
212+
combined_args.extend_from_slice(args_slice);
213+
(inner.func.clone(), combined_args, inner.keywords.clone())
214+
} else {
215+
(func.clone(), args_slice.to_vec(), vm.ctx.new_dict())
216+
};
217+
218+
// Add new keywords
219+
for (key, value) in args.kwargs {
220+
final_keywords.set_item(vm.ctx.intern_str(key.as_str()), value, vm)?;
221+
}
222+
223+
let partial = PyPartial {
224+
inner: PyRwLock::new(PyPartialInner {
225+
func: final_func,
226+
args: vm.ctx.new_tuple(final_args),
227+
keywords: final_keywords,
228+
}),
229+
};
230+
231+
partial.into_ref_with_type(vm, cls).map(Into::into)
232+
}
233+
}
234+
235+
impl Callable for PyPartial {
236+
type Args = FuncArgs;
237+
238+
fn call(zelf: &Py<Self>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
239+
let inner = zelf.inner.read();
240+
let mut combined_args = inner.args.as_slice().to_vec();
241+
combined_args.extend_from_slice(&args.args);
242+
243+
// Merge keywords from self.keywords and args.kwargs
244+
let mut final_kwargs = IndexMap::new();
245+
246+
// Add keywords from self.keywords
247+
for (key, value) in inner.keywords.clone() {
248+
let key_str = key
249+
.downcast::<crate::builtins::PyStr>()
250+
.map_err(|_| vm.new_type_error("keywords must be strings".to_owned()))?;
251+
final_kwargs.insert(key_str.as_str().to_owned(), value);
252+
}
253+
254+
// Add keywords from args.kwargs (these override self.keywords)
255+
for (key, value) in args.kwargs {
256+
final_kwargs.insert(key, value);
257+
}
258+
259+
inner
260+
.func
261+
.call(FuncArgs::new(combined_args, KwArgs::new(final_kwargs)), vm)
262+
}
263+
}
264+
265+
impl Representable for PyPartial {
266+
#[inline]
267+
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
268+
// Check for recursive repr
269+
let obj = zelf.as_object();
270+
if let Some(_guard) = ReprGuard::enter(vm, obj) {
271+
let inner = zelf.inner.read();
272+
let func_repr = inner.func.repr(vm)?;
273+
let mut parts = vec![func_repr.as_str().to_owned()];
274+
275+
for arg in inner.args.as_slice() {
276+
parts.push(arg.repr(vm)?.as_str().to_owned());
277+
}
278+
279+
for (key, value) in inner.keywords.clone() {
280+
// For string keys, use them directly without quotes
281+
let key_part = if let Ok(s) = key.clone().downcast::<crate::builtins::PyStr>() {
282+
s.as_str().to_owned()
283+
} else {
284+
// For non-string keys, convert to string using __str__
285+
key.str(vm)?.as_str().to_owned()
286+
};
287+
let value_str = value.repr(vm)?;
288+
parts.push(format!("{}={}", key_part, value_str.as_str()));
289+
}
290+
291+
let class_name = zelf.class().name();
292+
let module = zelf.class().module(vm);
293+
294+
// Check if this is a subclass by comparing with the base partial type
295+
let is_subclass = !zelf.class().is(PyPartial::class(&vm.ctx));
296+
297+
let qualified_name = if !is_subclass {
298+
// For the base partial class, always use functools.partial
299+
"functools.partial".to_string()
300+
} else {
301+
// For subclasses, check if they're defined in __main__ or test modules
302+
match module.downcast::<crate::builtins::PyStr>() {
303+
Ok(module_str) => {
304+
let module_name = module_str.as_str();
305+
match module_name {
306+
"builtins" | "" | "__main__" => class_name.to_string(),
307+
name if name.starts_with("test.") || name == "test" => {
308+
// For test modules, just use the class name without module prefix
309+
class_name.to_string()
310+
}
311+
_ => format!("{}.{}", module_name, class_name),
312+
}
313+
}
314+
Err(_) => class_name.to_string(),
315+
}
316+
};
317+
318+
Ok(format!("{}({})", qualified_name, parts.join(", ")))
319+
} else {
320+
Ok("...".to_owned())
321+
}
322+
}
323+
}
33324
}

0 commit comments

Comments
 (0)