diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 095187b9a..4e131b365 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -30,6 +30,7 @@ namespace Tensorflow public class BaseSession : DisposableObject { protected Graph _graph; + protected Status _status; public Graph graph => _graph; public BaseSession(IntPtr handle, Graph g) @@ -48,9 +49,9 @@ public BaseSession(string target = "", Graph g = null, ConfigProto config = null } using var opts = new SessionOptions(target, config); - status = status ?? tf.Status; - _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); - status.Check(true); + _status = status ?? tf.Status; + _handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle); + _status.Check(true); } public virtual void run(Operation op, params FeedItem[] feed_dict) @@ -217,8 +218,6 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] f // Ensure any changes to the graph are reflected in the runtime. _extend_graph(); - var status = tf.Status; - var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); c_api.TF_SessionRun(_handle, @@ -232,9 +231,9 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] f target_opers: target_list.Select(f => (IntPtr)f).ToArray(), ntargets: target_list.Count, run_metadata: IntPtr.Zero, - status: status.Handle); + status: _status.Handle); - status.Check(true); + _status.Check(true); var result = new NDArray[fetch_list.Length]; @@ -246,8 +245,6 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] f public unsafe Tensor eval(Tensor tensor) { - var status = tf.Status; - var output_values = new IntPtr[1]; var fetch_list = new[] { tensor._as_tf_output() }; @@ -262,9 +259,9 @@ public unsafe Tensor eval(Tensor tensor) target_opers: new IntPtr[0], ntargets: 0, run_metadata: IntPtr.Zero, - status: status.Handle); + status: _status.Handle); - status.Check(true); + _status.Check(true); return new Tensor(new SafeTensorHandle(output_values[0])); } @@ -291,15 +288,7 @@ private void _extend_graph() protected override void DisposeUnmanagedResources(IntPtr handle) { // c_api.TF_CloseSession(handle, tf.Status.Handle); - if (tf.Status == null || tf.Status.Handle.IsInvalid) - { - using var status = new Status(); - c_api.TF_DeleteSession(handle, status.Handle); - } - else - { - c_api.TF_DeleteSession(handle, tf.Status.Handle); - } + c_api.TF_DeleteSession(handle, _status.Handle); } } }