@@ -2,7 +2,18 @@ pub(crate) use _functools::make_module;
2
2
3
3
#[ pymodule]
4
4
mod _functools {
5
- use crate :: { PyObjectRef , PyResult , VirtualMachine , function:: OptionalArg , protocol:: PyIter } ;
5
+ use crate :: {
6
+ Py , PyObjectRef , PyPayload , PyRef , PyResult , VirtualMachine ,
7
+ builtins:: { PyDict , PyTuple , PyTypeRef } ,
8
+ common:: lock:: PyRwLock ,
9
+ function:: { FuncArgs , KwArgs , OptionalArg } ,
10
+ object:: AsObject ,
11
+ protocol:: PyIter ,
12
+ pyclass,
13
+ recursion:: ReprGuard ,
14
+ types:: { Callable , Constructor , Representable } ,
15
+ } ;
16
+ use indexmap:: IndexMap ;
6
17
7
18
#[ pyfunction]
8
19
fn reduce (
@@ -30,4 +41,284 @@ mod _functools {
30
41
}
31
42
Ok ( accumulator)
32
43
}
44
+
45
+ #[ pyattr]
46
+ #[ pyclass( name = "partial" , module = "_functools" ) ]
47
+ #[ derive( Debug , PyPayload ) ]
48
+ pub struct PyPartial {
49
+ inner : PyRwLock < PyPartialInner > ,
50
+ }
51
+
52
+ #[ derive( Debug ) ]
53
+ struct PyPartialInner {
54
+ func : PyObjectRef ,
55
+ args : PyRef < PyTuple > ,
56
+ keywords : PyRef < PyDict > ,
57
+ }
58
+
59
+ #[ pyclass( with( Constructor , Callable , Representable ) , flags( BASETYPE , HAS_DICT ) ) ]
60
+ impl PyPartial {
61
+ #[ pygetset]
62
+ fn func ( & self ) -> PyObjectRef {
63
+ self . inner . read ( ) . func . clone ( )
64
+ }
65
+
66
+ #[ pygetset]
67
+ fn args ( & self ) -> PyRef < PyTuple > {
68
+ self . inner . read ( ) . args . clone ( )
69
+ }
70
+
71
+ #[ pygetset]
72
+ fn keywords ( & self ) -> PyRef < PyDict > {
73
+ self . inner . read ( ) . keywords . clone ( )
74
+ }
75
+
76
+ #[ pymethod( name = "__reduce__" ) ]
77
+ fn reduce ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult {
78
+ let inner = zelf. inner . read ( ) ;
79
+ let partial_type = zelf. class ( ) ;
80
+
81
+ // Get __dict__ if it exists and is not empty
82
+ let dict_obj = match zelf. as_object ( ) . dict ( ) {
83
+ Some ( dict) if !dict. is_empty ( ) => dict. into ( ) ,
84
+ _ => vm. ctx . none ( ) ,
85
+ } ;
86
+
87
+ let state = vm. ctx . new_tuple ( vec ! [
88
+ inner. func. clone( ) ,
89
+ inner. args. clone( ) . into( ) ,
90
+ inner. keywords. clone( ) . into( ) ,
91
+ dict_obj,
92
+ ] ) ;
93
+ Ok ( vm
94
+ . ctx
95
+ . new_tuple ( vec ! [
96
+ partial_type. to_owned( ) . into( ) ,
97
+ vm. ctx. new_tuple( vec![ inner. func. clone( ) ] ) . into( ) ,
98
+ state. into( ) ,
99
+ ] )
100
+ . into ( ) )
101
+ }
102
+
103
+ #[ pymethod( magic) ]
104
+ fn setstate ( zelf : & Py < Self > , state : PyObjectRef , vm : & VirtualMachine ) -> PyResult < ( ) > {
105
+ let state_tuple = state. downcast :: < PyTuple > ( ) . map_err ( |_| {
106
+ vm. new_type_error ( "argument to __setstate__ must be a tuple" . to_owned ( ) )
107
+ } ) ?;
108
+
109
+ if state_tuple. len ( ) != 4 {
110
+ return Err ( vm. new_type_error ( format ! (
111
+ "expected 4 items in state, got {}" ,
112
+ state_tuple. len( )
113
+ ) ) ) ;
114
+ }
115
+
116
+ let func = & state_tuple[ 0 ] ;
117
+ let args = & state_tuple[ 1 ] ;
118
+ let kwds = & state_tuple[ 2 ] ;
119
+ let dict = & state_tuple[ 3 ] ;
120
+
121
+ if !func. is_callable ( ) {
122
+ return Err ( vm. new_type_error ( "invalid partial state" . to_owned ( ) ) ) ;
123
+ }
124
+
125
+ // Validate that args is a tuple (or subclass)
126
+ if !args. fast_isinstance ( vm. ctx . types . tuple_type ) {
127
+ return Err ( vm. new_type_error ( "invalid partial state" . to_owned ( ) ) ) ;
128
+ }
129
+ // Always convert to base tuple, even if it's a subclass
130
+ let args_tuple = match args. clone ( ) . downcast :: < PyTuple > ( ) {
131
+ Ok ( tuple) if tuple. class ( ) . is ( vm. ctx . types . tuple_type ) => tuple,
132
+ _ => {
133
+ // It's a tuple subclass, convert to base tuple
134
+ let elements: Vec < PyObjectRef > = args. try_to_value ( vm) ?;
135
+ vm. ctx . new_tuple ( elements)
136
+ }
137
+ } ;
138
+
139
+ let keywords_dict = if kwds. is ( & vm. ctx . none ) {
140
+ vm. ctx . new_dict ( )
141
+ } else {
142
+ // Always convert to base dict, even if it's a subclass
143
+ let dict = kwds
144
+ . clone ( )
145
+ . downcast :: < PyDict > ( )
146
+ . map_err ( |_| vm. new_type_error ( "invalid partial state" . to_owned ( ) ) ) ?;
147
+ if dict. class ( ) . is ( vm. ctx . types . dict_type ) {
148
+ // It's already a base dict
149
+ dict
150
+ } else {
151
+ // It's a dict subclass, convert to base dict
152
+ let new_dict = vm. ctx . new_dict ( ) ;
153
+ for ( key, value) in dict {
154
+ new_dict. set_item ( & * key, value, vm) ?;
155
+ }
156
+ new_dict
157
+ }
158
+ } ;
159
+
160
+ // Actually update the state
161
+ let mut inner = zelf. inner . write ( ) ;
162
+ inner. func = func. clone ( ) ;
163
+ // Handle args - use the already validated tuple
164
+ inner. args = args_tuple;
165
+
166
+ // Handle keywords - keep the original type
167
+ inner. keywords = keywords_dict;
168
+
169
+ // Update __dict__ if provided
170
+ let Some ( instance_dict) = zelf. as_object ( ) . dict ( ) else {
171
+ return Ok ( ( ) ) ;
172
+ } ;
173
+
174
+ if dict. is ( & vm. ctx . none ) {
175
+ // If dict is None, clear the instance dict
176
+ instance_dict. clear ( ) ;
177
+ return Ok ( ( ) ) ;
178
+ }
179
+
180
+ let dict_obj = dict
181
+ . clone ( )
182
+ . downcast :: < PyDict > ( )
183
+ . map_err ( |_| vm. new_type_error ( "invalid partial state" . to_owned ( ) ) ) ?;
184
+
185
+ // Clear existing dict and update with new values
186
+ instance_dict. clear ( ) ;
187
+ for ( key, value) in dict_obj {
188
+ instance_dict. set_item ( & * key, value, vm) ?;
189
+ }
190
+
191
+ Ok ( ( ) )
192
+ }
193
+ }
194
+
195
+ impl Constructor for PyPartial {
196
+ type Args = FuncArgs ;
197
+
198
+ fn py_new ( cls : PyTypeRef , args : Self :: Args , vm : & VirtualMachine ) -> PyResult {
199
+ let ( func, args_slice) = args. args . split_first ( ) . ok_or_else ( || {
200
+ vm. new_type_error ( "partial expected at least 1 argument, got 0" . to_owned ( ) )
201
+ } ) ?;
202
+
203
+ if !func. is_callable ( ) {
204
+ return Err ( vm. new_type_error ( "the first argument must be callable" . to_owned ( ) ) ) ;
205
+ }
206
+
207
+ // Handle nested partial objects
208
+ let ( final_func, final_args, final_keywords) =
209
+ if let Some ( partial) = func. downcast_ref :: < PyPartial > ( ) {
210
+ let inner = partial. inner . read ( ) ;
211
+ let mut combined_args = inner. args . as_slice ( ) . to_vec ( ) ;
212
+ combined_args. extend_from_slice ( args_slice) ;
213
+ ( inner. func . clone ( ) , combined_args, inner. keywords . clone ( ) )
214
+ } else {
215
+ ( func. clone ( ) , args_slice. to_vec ( ) , vm. ctx . new_dict ( ) )
216
+ } ;
217
+
218
+ // Add new keywords
219
+ for ( key, value) in args. kwargs {
220
+ final_keywords. set_item ( vm. ctx . intern_str ( key. as_str ( ) ) , value, vm) ?;
221
+ }
222
+
223
+ let partial = PyPartial {
224
+ inner : PyRwLock :: new ( PyPartialInner {
225
+ func : final_func,
226
+ args : vm. ctx . new_tuple ( final_args) ,
227
+ keywords : final_keywords,
228
+ } ) ,
229
+ } ;
230
+
231
+ partial. into_ref_with_type ( vm, cls) . map ( Into :: into)
232
+ }
233
+ }
234
+
235
+ impl Callable for PyPartial {
236
+ type Args = FuncArgs ;
237
+
238
+ fn call ( zelf : & Py < Self > , args : FuncArgs , vm : & VirtualMachine ) -> PyResult {
239
+ let inner = zelf. inner . read ( ) ;
240
+ let mut combined_args = inner. args . as_slice ( ) . to_vec ( ) ;
241
+ combined_args. extend_from_slice ( & args. args ) ;
242
+
243
+ // Merge keywords from self.keywords and args.kwargs
244
+ let mut final_kwargs = IndexMap :: new ( ) ;
245
+
246
+ // Add keywords from self.keywords
247
+ for ( key, value) in inner. keywords . clone ( ) {
248
+ let key_str = key
249
+ . downcast :: < crate :: builtins:: PyStr > ( )
250
+ . map_err ( |_| vm. new_type_error ( "keywords must be strings" . to_owned ( ) ) ) ?;
251
+ final_kwargs. insert ( key_str. as_str ( ) . to_owned ( ) , value) ;
252
+ }
253
+
254
+ // Add keywords from args.kwargs (these override self.keywords)
255
+ for ( key, value) in args. kwargs {
256
+ final_kwargs. insert ( key, value) ;
257
+ }
258
+
259
+ inner
260
+ . func
261
+ . call ( FuncArgs :: new ( combined_args, KwArgs :: new ( final_kwargs) ) , vm)
262
+ }
263
+ }
264
+
265
+ impl Representable for PyPartial {
266
+ #[ inline]
267
+ fn repr_str ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < String > {
268
+ // Check for recursive repr
269
+ let obj = zelf. as_object ( ) ;
270
+ if let Some ( _guard) = ReprGuard :: enter ( vm, obj) {
271
+ let inner = zelf. inner . read ( ) ;
272
+ let func_repr = inner. func . repr ( vm) ?;
273
+ let mut parts = vec ! [ func_repr. as_str( ) . to_owned( ) ] ;
274
+
275
+ for arg in inner. args . as_slice ( ) {
276
+ parts. push ( arg. repr ( vm) ?. as_str ( ) . to_owned ( ) ) ;
277
+ }
278
+
279
+ for ( key, value) in inner. keywords . clone ( ) {
280
+ // For string keys, use them directly without quotes
281
+ let key_part = if let Ok ( s) = key. clone ( ) . downcast :: < crate :: builtins:: PyStr > ( ) {
282
+ s. as_str ( ) . to_owned ( )
283
+ } else {
284
+ // For non-string keys, convert to string using __str__
285
+ key. str ( vm) ?. as_str ( ) . to_owned ( )
286
+ } ;
287
+ let value_str = value. repr ( vm) ?;
288
+ parts. push ( format ! ( "{}={}" , key_part, value_str. as_str( ) ) ) ;
289
+ }
290
+
291
+ let class_name = zelf. class ( ) . name ( ) ;
292
+ let module = zelf. class ( ) . module ( vm) ;
293
+
294
+ // Check if this is a subclass by comparing with the base partial type
295
+ let is_subclass = !zelf. class ( ) . is ( PyPartial :: class ( & vm. ctx ) ) ;
296
+
297
+ let qualified_name = if !is_subclass {
298
+ // For the base partial class, always use functools.partial
299
+ "functools.partial" . to_string ( )
300
+ } else {
301
+ // For subclasses, check if they're defined in __main__ or test modules
302
+ match module. downcast :: < crate :: builtins:: PyStr > ( ) {
303
+ Ok ( module_str) => {
304
+ let module_name = module_str. as_str ( ) ;
305
+ match module_name {
306
+ "builtins" | "" | "__main__" => class_name. to_string ( ) ,
307
+ name if name. starts_with ( "test." ) || name == "test" => {
308
+ // For test modules, just use the class name without module prefix
309
+ class_name. to_string ( )
310
+ }
311
+ _ => format ! ( "{}.{}" , module_name, class_name) ,
312
+ }
313
+ }
314
+ Err ( _) => class_name. to_string ( ) ,
315
+ }
316
+ } ;
317
+
318
+ Ok ( format ! ( "{}({})" , qualified_name, parts. join( ", " ) ) )
319
+ } else {
320
+ Ok ( "..." . to_owned ( ) )
321
+ }
322
+ }
323
+ }
33
324
}
0 commit comments