Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5bb1007

Browse files
committed
ensured, that __call__ can be inherited through multiple levels of hierarchy
1 parent e964c6d commit 5bb1007

File tree

2 files changed

+58
-34
lines changed

2 files changed

+58
-34
lines changed

src/embed_tests/CallableObject.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public void Dispose() {
2323

2424
[Test]
2525
public void CallMethodMakesObjectCallable() {
26-
var doubler = new Doubler();
26+
var doubler = new DerivedDoubler();
2727
using (Py.GIL()) {
2828
dynamic applyObjectTo21 = PythonEngine.Eval("lambda o: o(21)");
2929
Assert.AreEqual(doubler.__call__(21), (int)applyObjectTo21(doubler.ToPython()));
@@ -43,13 +43,17 @@ class Doubler {
4343
public int __call__(int arg) => 2 * arg;
4444
}
4545

46+
class DerivedDoubler : Doubler { }
47+
4648
[CustomBaseType]
4749
class CallViaInheritance {
4850
public const string BaseClassName = "Forwarder";
4951
public static readonly string BaseClassSource = $@"
50-
class {BaseClassName}:
52+
class MyCallableBase:
5153
def __call__(self, val):
5254
return self.Call(val)
55+
56+
class {BaseClassName}(MyCallableBase): pass
5357
";
5458
public int Call(int arg) => 3 * arg;
5559
}

src/runtime/classobject.cs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Diagnostics;
4+
using System.Linq;
45
using System.Reflection;
56
using System.Runtime.InteropServices;
67

@@ -297,44 +298,31 @@ public static IntPtr tp_call(IntPtr ob, IntPtr args, IntPtr kw)
297298

298299
if (cb.type != typeof(Delegate))
299300
{
300-
IntPtr dict = Marshal.ReadIntPtr(tp, TypeOffset.tp_dict);
301-
IntPtr methodObjectHandle = Runtime.PyDict_GetItemString(dict, "__call__");
302-
if (methodObjectHandle == IntPtr.Zero || methodObjectHandle == Runtime.PyNone)
303-
{
304-
Exceptions.SetError(Exceptions.TypeError, "object is not callable");
305-
return IntPtr.Zero;
306-
}
307-
308-
if (GetManagedObject(methodObjectHandle) is MethodObject methodObject)
309-
{
310-
return methodObject.Invoke(ob, args, kw);
301+
var calls = cb.type.GetMethods(BindingFlags.Public | BindingFlags.Instance)
302+
.Where(m => m.Name == "__call__")
303+
.ToList();
304+
if (calls.Count > 0) {
305+
var callBinder = new MethodBinder();
306+
foreach (MethodInfo call in calls) {
307+
callBinder.AddMethod(call);
308+
}
309+
return callBinder.Invoke(ob, args, kw);
311310
}
312311

313-
methodObjectHandle = IntPtr.Zero;
314-
315-
foreach (IntPtr pythonBase in GetPythonBases(tp)) {
316-
dict = Marshal.ReadIntPtr(pythonBase, TypeOffset.tp_dict);
312+
using var super = new PyObject(Runtime.SelfIncRef(Runtime.PySuper));
313+
using var self = new PyObject(Runtime.SelfIncRef(ob));
314+
using var none = new PyObject(Runtime.SelfIncRef(Runtime.PyNone));
315+
foreach (IntPtr managedTypeDerivingFromPython in GetTypesWithPythonBasesInHierarchy(tp)) {
316+
using var @base = super.Invoke(new PyObject(managedTypeDerivingFromPython), self);
317+
using var call = @base.GetAttrOrElse("__call__", none);
317318

318-
methodObjectHandle = Runtime.PyDict_GetItemString(dict, "__call__");
319-
if (methodObjectHandle != IntPtr.Zero && methodObjectHandle != Runtime.PyNone) break;
320-
}
319+
if (call.Handle == Runtime.PyNone) continue;
321320

322-
if (methodObjectHandle == IntPtr.Zero || methodObjectHandle == Runtime.PyNone) {
323-
Exceptions.SetError(Exceptions.TypeError, "object is not callable");
324-
return IntPtr.Zero;
321+
return Runtime.PyObject_Call(call.Handle, args, kw);
325322
}
326323

327-
var boundMethod = Runtime.PyMethod_New(methodObjectHandle, ob);
328-
if (boundMethod == IntPtr.Zero) { return IntPtr.Zero; }
329-
330-
try
331-
{
332-
return Runtime.PyObject_Call(boundMethod, args, kw);
333-
}
334-
finally
335-
{
336-
Runtime.XDecref(boundMethod);
337-
}
324+
Exceptions.SetError(Exceptions.TypeError, "object is not callable");
325+
return IntPtr.Zero;
338326
}
339327

340328
var co = (CLRObject)GetManagedObject(ob);
@@ -377,6 +365,38 @@ internal static IEnumerable<IntPtr> GetPythonBases(IntPtr tp) {
377365
yield return tp;
378366
}
379367

368+
internal static IEnumerable<IntPtr> GetTypesWithPythonBasesInHierarchy(IntPtr tp) {
369+
Debug.Assert(IsManagedType(tp));
370+
371+
var candidateQueue = new Queue<IntPtr>();
372+
candidateQueue.Enqueue(tp);
373+
while (candidateQueue.Count > 0) {
374+
tp = candidateQueue.Dequeue();
375+
IntPtr bases = Marshal.ReadIntPtr(tp, TypeOffset.tp_bases);
376+
if (bases != IntPtr.Zero) {
377+
long baseCount = Runtime.PyTuple_Size(bases);
378+
bool hasPythonBase = false;
379+
for (long baseIndex = 0; baseIndex < baseCount; baseIndex++) {
380+
IntPtr @base = Runtime.PyTuple_GetItem(bases, baseIndex);
381+
if (IsManagedType(@base)) {
382+
candidateQueue.Enqueue(@base);
383+
} else {
384+
hasPythonBase = true;
385+
}
386+
}
387+
388+
if (hasPythonBase) yield return tp;
389+
} else {
390+
tp = Marshal.ReadIntPtr(tp, TypeOffset.tp_base);
391+
if (tp != IntPtr.Zero && IsManagedType(tp))
392+
candidateQueue.Enqueue(tp);
393+
}
394+
}
395+
}
396+
397+
/// <summary>
398+
/// Checks if specified type is a CLR type
399+
/// </summary>
380400
internal static bool IsManagedType(IntPtr tp)
381401
{
382402
var flags = Util.ReadCLong(tp, TypeOffset.tp_flags);

0 commit comments

Comments
 (0)