From 67a2fcd6285d280ebf0deaed50c26920322170fa Mon Sep 17 00:00:00 2001 From: Superpiffer Date: Mon, 6 Feb 2023 12:42:56 +0100 Subject: [PATCH] Use a local Status variable Using a local reference ensure that the Status object cannot be disposed before the Dispose. This way it's also possible to use an external Status instance instead of the static one, if needed. --- .../Sessions/BaseSession.cs | 29 ++++++------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 095187b9a..4e131b365 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -30,6 +30,7 @@ namespace Tensorflow public class BaseSession : DisposableObject { protected Graph _graph; + protected Status _status; public Graph graph => _graph; public BaseSession(IntPtr handle, Graph g) @@ -48,9 +49,9 @@ public BaseSession(string target = "", Graph g = null, ConfigProto config = null } using var opts = new SessionOptions(target, config); - status = status ?? tf.Status; - _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); - status.Check(true); + _status = status ?? tf.Status; + _handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle); + _status.Check(true); } public virtual void run(Operation op, params FeedItem[] feed_dict) @@ -217,8 +218,6 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] f // Ensure any changes to the graph are reflected in the runtime. _extend_graph(); - var status = tf.Status; - var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); c_api.TF_SessionRun(_handle, @@ -232,9 +231,9 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] f target_opers: target_list.Select(f => (IntPtr)f).ToArray(), ntargets: target_list.Count, run_metadata: IntPtr.Zero, - status: status.Handle); + status: _status.Handle); - status.Check(true); + _status.Check(true); var result = new NDArray[fetch_list.Length]; @@ -246,8 +245,6 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] f public unsafe Tensor eval(Tensor tensor) { - var status = tf.Status; - var output_values = new IntPtr[1]; var fetch_list = new[] { tensor._as_tf_output() }; @@ -262,9 +259,9 @@ public unsafe Tensor eval(Tensor tensor) target_opers: new IntPtr[0], ntargets: 0, run_metadata: IntPtr.Zero, - status: status.Handle); + status: _status.Handle); - status.Check(true); + _status.Check(true); return new Tensor(new SafeTensorHandle(output_values[0])); } @@ -291,15 +288,7 @@ private void _extend_graph() protected override void DisposeUnmanagedResources(IntPtr handle) { // c_api.TF_CloseSession(handle, tf.Status.Handle); - if (tf.Status == null || tf.Status.Handle.IsInvalid) - { - using var status = new Status(); - c_api.TF_DeleteSession(handle, status.Handle); - } - else - { - c_api.TF_DeleteSession(handle, tf.Status.Handle); - } + c_api.TF_DeleteSession(handle, _status.Handle); } } }