diff --git a/src/embed_tests/Codecs.cs b/src/embed_tests/Codecs.cs index 600215cf0..90d13fb1c 100644 --- a/src/embed_tests/Codecs.cs +++ b/src/embed_tests/Codecs.cs @@ -6,28 +6,34 @@ namespace Python.EmbeddingTest { using Python.Runtime; using Python.Runtime.Codecs; - public class Codecs { + public class Codecs + { [SetUp] - public void SetUp() { + public void SetUp() + { PythonEngine.Initialize(); } [TearDown] - public void Dispose() { + public void Dispose() + { PythonEngine.Shutdown(); } [Test] - public void ConversionsGeneric() { + public void ConversionsGeneric() + { ConversionsGeneric, ValueTuple>(); } - static void ConversionsGeneric() { + static void ConversionsGeneric() + { TupleCodec.Register(); var tuple = Activator.CreateInstance(typeof(T), 42, "42", new object()); T restored = default; using (Py.GIL()) - using (var scope = Py.CreateScope()) { + using (var scope = Py.CreateScope()) + { void Accept(T value) => restored = value; var accept = new Action(Accept).ToPython(); scope.Set(nameof(tuple), tuple); @@ -38,15 +44,18 @@ static void ConversionsGeneric() { } [Test] - public void ConversionsObject() { + public void ConversionsObject() + { ConversionsObject, ValueTuple>(); } - static void ConversionsObject() { + static void ConversionsObject() + { TupleCodec.Register(); var tuple = Activator.CreateInstance(typeof(T), 42, "42", new object()); T restored = default; using (Py.GIL()) - using (var scope = Py.CreateScope()) { + using (var scope = Py.CreateScope()) + { void Accept(object value) => restored = (T)value; var accept = new Action(Accept).ToPython(); scope.Set(nameof(tuple), tuple); @@ -57,12 +66,15 @@ static void ConversionsObject() { } [Test] - public void TupleRoundtripObject() { + public void TupleRoundtripObject() + { TupleRoundtripObject, ValueTuple>(); } - static void TupleRoundtripObject() { + static void TupleRoundtripObject() + { var tuple = Activator.CreateInstance(typeof(T), 42, "42", new object()); - using (Py.GIL()) { + using (Py.GIL()) + { var pyTuple = TupleCodec.Instance.TryEncode(tuple); Assert.IsTrue(TupleCodec.Instance.TryDecode(pyTuple, out object restored)); Assert.AreEqual(expected: tuple, actual: restored); @@ -70,17 +82,108 @@ static void TupleRoundtripObject() { } [Test] - public void TupleRoundtripGeneric() { + public void TupleRoundtripGeneric() + { TupleRoundtripGeneric, ValueTuple>(); } - static void TupleRoundtripGeneric() { + static void TupleRoundtripGeneric() + { var tuple = Activator.CreateInstance(typeof(T), 42, "42", new object()); - using (Py.GIL()) { + using (Py.GIL()) + { var pyTuple = TupleCodec.Instance.TryEncode(tuple); Assert.IsTrue(TupleCodec.Instance.TryDecode(pyTuple, out T restored)); Assert.AreEqual(expected: tuple, actual: restored); } } + + [Test] + public void FunctionAction() + { + var codec = FunctionCodec.Instance; + + PyInt x = new PyInt(1); + PyDict y = new PyDict(); + //non-callables can't be decoded into Action + Assert.IsFalse(codec.CanDecode(x, typeof(Action))); + Assert.IsFalse(codec.CanDecode(y, typeof(Action))); + + var locals = new PyDict(); + PythonEngine.Exec(@" +def foo(): + return 1 +def bar(a): + return 2 +", null, locals.Handle); + + //foo, the function with no arguments + var fooFunc = locals.GetItem("foo"); + Assert.IsFalse(codec.CanDecode(fooFunc, typeof(bool))); + + //CanDecode does not work for variadic actions + //Assert.IsFalse(codec.CanDecode(fooFunc, typeof(Action))); + Assert.IsTrue(codec.CanDecode(fooFunc, typeof(Action))); + + Action fooAction; + Assert.IsTrue(codec.TryDecode(fooFunc, out fooAction)); + Assert.DoesNotThrow(() => fooAction()); + + //bar, the function with an argument + var barFunc = locals.GetItem("bar"); + Assert.IsFalse(codec.CanDecode(barFunc, typeof(bool))); + //Assert.IsFalse(codec.CanDecode(barFunc, typeof(Action))); + Assert.IsTrue(codec.CanDecode(barFunc, typeof(Action))); + + Action barAction; + Assert.IsTrue(codec.TryDecode(barFunc, out barAction)); + Assert.DoesNotThrow(() => barAction(new[] { (object)true })); + } + + [Test] + public void FunctionFunc() + { + var codec = FunctionCodec.Instance; + + PyInt x = new PyInt(1); + PyDict y = new PyDict(); + //non-callables can't be decoded into Func + Assert.IsFalse(codec.CanDecode(x, typeof(Func))); + Assert.IsFalse(codec.CanDecode(y, typeof(Func))); + + var locals = new PyDict(); + PythonEngine.Exec(@" +def foo(): + return 1 +def bar(a): + return 2 +", null, locals.Handle); + + //foo, the function with no arguments + var fooFunc = locals.GetItem("foo"); + Assert.IsFalse(codec.CanDecode(fooFunc, typeof(bool))); + + //CanDecode does not work for variadic actions + //Assert.IsFalse(codec.CanDecode(fooFunc, typeof(Func))); + Assert.IsTrue(codec.CanDecode(fooFunc, typeof(Func))); + + Func foo; + Assert.IsTrue(codec.TryDecode(fooFunc, out foo)); + object res1 = null; + Assert.DoesNotThrow(() => res1 = foo()); + Assert.AreEqual(res1, 1); + + //bar, the function with an argument + var barFunc = locals.GetItem("bar"); + Assert.IsFalse(codec.CanDecode(barFunc, typeof(bool))); + //Assert.IsFalse(codec.CanDecode(barFunc, typeof(Func))); + Assert.IsTrue(codec.CanDecode(barFunc, typeof(Func))); + + Func bar; + Assert.IsTrue(codec.TryDecode(barFunc, out bar)); + object res2 = null; + Assert.DoesNotThrow(() => res2 = bar(new[] { (object)true })); + Assert.AreEqual(res2, 2); + } } } diff --git a/src/embed_tests/TestCallbacks.cs b/src/embed_tests/TestCallbacks.cs index 220b0a86a..eb23c74d0 100644 --- a/src/embed_tests/TestCallbacks.cs +++ b/src/embed_tests/TestCallbacks.cs @@ -31,5 +31,77 @@ public void TestNoOverloadException() { StringAssert.EndsWith(expectedArgTypes, error.Message); } } + + private class Callables + { + internal object CallFunction0(Func func) + { + return func(); + } + + internal object CallFunction1(Func func, object arg) + { + return func(new[] { arg}); + } + + internal void CallAction0(Action func) + { + func(); + } + + internal void CallAction1(Action func, object arg) + { + func(new[] { arg }); + } + } + + [Test] + public void TestPythonFunctionPassedIntoCLRMethod() + { + var locals = new PyDict(); + PythonEngine.Exec(@" +def ret_1(): + return 1 +def str_len(a): + return len(a) +", null, locals.Handle); + + var ret1 = locals.GetItem("ret_1"); + var strLen = locals.GetItem("str_len"); + + var callables = new Callables(); + + Python.Runtime.Codecs.FunctionCodec.Register(); + + //ret1. A function with no arguments that returns an integer + //it must be convertible to Action or Func and not to Func + { + Assert.IsTrue(Converter.ToManaged(ret1.Handle, typeof(Action), out var result1, false)); + Assert.IsTrue(Converter.ToManaged(ret1.Handle, typeof(Func), out var result2, false)); + + Assert.DoesNotThrow(() => { callables.CallAction0((Action)result1); }); + object ret2 = null; + Assert.DoesNotThrow(() => { ret2 = callables.CallFunction0((Func)result2); }); + Assert.AreEqual(ret2, 1); + } + + //strLen. A function that takes something with a __len__ and returns the result of that function + //It must be convertible to an Action and Func) and not to an Action or Func + { + Assert.IsTrue(Converter.ToManaged(strLen.Handle, typeof(Action), out var result3, false)); + Assert.IsTrue(Converter.ToManaged(strLen.Handle, typeof(Func), out var result4, false)); + + //try using both func and action to show you can get __len__ of a string but not an integer + Assert.Throws(() => { callables.CallAction1((Action)result3, 2); }); + Assert.DoesNotThrow(() => { callables.CallAction1((Action)result3, "hello"); }); + Assert.Throws(() => { callables.CallFunction1((Func)result4, 2); }); + + object ret2 = null; + Assert.DoesNotThrow(() => { ret2 = callables.CallFunction1((Func)result4, "hello"); }); + Assert.AreEqual(ret2, 5); + } + + PyObjectConversions.Reset(); + } } } diff --git a/src/runtime/Codecs/FunctionCodec.cs b/src/runtime/Codecs/FunctionCodec.cs new file mode 100644 index 000000000..ba75ecf60 --- /dev/null +++ b/src/runtime/Codecs/FunctionCodec.cs @@ -0,0 +1,191 @@ +using System; +using System.Reflection; + +namespace Python.Runtime.Codecs +{ + //converts python functions to C# actions + class FunctionCodec : IPyObjectDecoder + { + private static int GetNumArgs(PyObject pyCallable) + { + var locals = new PyDict(); + locals.SetItem("f", pyCallable); + using (Py.GIL()) + PythonEngine.Exec(@" +from inspect import signature +try: + x = len(signature(f).parameters) +except: + x = 0 +", null, locals.Handle); + + var x = locals.GetItem("x"); + return new PyInt(x).ToInt32(); + } + + private static int GetNumArgs(Type targetType) + { + MethodInfo invokeMethod = targetType.GetMethod("Invoke"); + return invokeMethod.GetParameters().Length; + } + + private static bool IsUnaryAction(Type targetType) + { + return targetType == typeof(Action); + } + + private static bool IsVariadicObjectAction(Type targetType) + { + return targetType == typeof(Action); + } + + private static bool IsUnaryFunc(Type targetType) + { + return targetType == typeof(Func); + } + + private static bool IsVariadicObjectFunc(Type targetType) + { + return targetType == typeof(Func); + } + + private static bool IsAction(Type targetType) + { + return IsUnaryAction(targetType) || IsVariadicObjectAction(targetType); + } + + private static bool IsFunc(Type targetType) + { + return IsUnaryFunc(targetType) || IsVariadicObjectFunc(targetType); + } + + private static bool IsCallable(Type targetType) + { + return ClassManager.IsDelegate(targetType); + } + + public static FunctionCodec Instance { get; } = new FunctionCodec(); + public bool CanDecode(PyObject objectType, Type targetType) + { + //python object must be callable + if (!objectType.IsCallable()) return false; + + //C# object must be callable + if (!IsCallable(targetType)) + return false; + + return true; + } + + private static object ConvertUnaryAction(PyObject pyObj) + { + Func func = (Func)ConvertUnaryFunc(pyObj); + Action action = () => { func(); }; + return (object)action; + } + + private static object ConvertVariadicObjectAction(PyObject pyObj, int numArgs) + { + Func func = (Func)ConvertVariadicObjectFunc(pyObj, numArgs); + Action action = (object[] args) => { func(args); }; + return (object)action; + } + + //TODO share code between ConvertUnaryFunc and ConvertVariadicObjectFunc + private static object ConvertUnaryFunc(PyObject pyObj) + { + Func func = () => + { + Runtime.XIncref(pyObj.Handle); + PyObject pyAction = new PyObject(pyObj.Handle); + var pyArgs = new PyObject[0]; + using (Py.GIL()) + { + var pyResult = pyAction.Invoke(pyArgs); + Runtime.XIncref(pyResult.Handle); + Converter.ToManaged(pyResult.Handle, typeof(object), out var result, true); + return result; + } + }; + return (object)func; + } + + private static object ConvertVariadicObjectFunc(PyObject pyObj, int numArgs) + { + Func func = (object[] o) => + { + Runtime.XIncref(pyObj.Handle); + PyObject pyAction = new PyObject(pyObj.Handle); + var pyArgs = new PyObject[numArgs]; + int i = 0; + foreach (object obj in o) + { + pyArgs[i++] = new PyObject(Converter.ToPython(obj)); + } + + using (Py.GIL()) + { + var pyResult = pyAction.Invoke(pyArgs); + Runtime.XIncref(pyResult.Handle); + object result; + Converter.ToManaged(pyResult.Handle, typeof(object), out result, true); + return result; + } + }; + return (object)func; + } + + public bool TryDecode(PyObject pyObj, out T value) + { + value = default(T); + var tT = typeof(T); + if (!IsCallable(tT)) + return false; + + var numArgs = GetNumArgs(pyObj); + if (numArgs != GetNumArgs(tT)) + return false; + + if (IsAction(tT)) + { + object actionObj = null; + if (numArgs == 0) + { + actionObj = ConvertUnaryAction(pyObj); + } + else + { + actionObj = ConvertVariadicObjectAction(pyObj, numArgs); + } + + value = (T)actionObj; + return true; + } + else if (IsFunc(tT)) + { + + object funcObj = null; + if (numArgs == 0) + { + funcObj = ConvertUnaryFunc(pyObj); + } + else + { + funcObj = ConvertVariadicObjectFunc(pyObj, numArgs); + } + + value = (T)funcObj; + return true; + } + else + { + return false; + } + } + + public static void Register() + { + PyObjectConversions.RegisterDecoder(Instance); + } + } +} diff --git a/src/runtime/classmanager.cs b/src/runtime/classmanager.cs index 0b084a49d..aad768536 100644 --- a/src/runtime/classmanager.cs +++ b/src/runtime/classmanager.cs @@ -39,6 +39,11 @@ public static void Reset() cache = new Dictionary(128); } + public static bool IsDelegate(Type type) + { + return type.IsSubclassOf(dtype); + } + /// /// Return the ClassBase-derived instance that implements a particular /// reflected managed type, creating it if it doesn't yet exist. @@ -83,7 +88,7 @@ private static ClassBase CreateClass(Type type) impl = new GenericType(type); } - else if (type.IsSubclassOf(dtype)) + else if (IsDelegate(type)) { impl = new DelegateObject(type); }