diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 794c32673..93a54af00 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -269,7 +269,7 @@ public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataTyp
 
         public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null)
             => tf.Context.ExecuteOp("Unpack", name, new ExecuteOpArgs(value, num)
-                .SetAttributes(new { axis }));
+                .SetAttributes(new { axis, num }));
 
         public static Tensor where(Tensor condition, string name = null)
         {
diff --git a/test/TensorFlowNET.Keras.UnitTest/Gradient.cs b/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
new file mode 100644
index 000000000..159800e10
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
@@ -0,0 +1,76 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Linq;
+using Tensorflow;
+using Tensorflow.Keras.Engine;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+using Tensorflow.NumPy;
+
+namespace TensorFlowNET.Keras.UnitTest;
+
+[TestClass]
+public class GradientTest
+{
+    public Model get_actor(int num_states)
+    {
+        var inputs = keras.layers.Input(shape: num_states);
+        var outputs = keras.layers.Dense(1, activation: keras.activations.Tanh).Apply(inputs);
+
+        Model model = keras.Model(inputs, outputs);
+
+        return model;
+    }
+
+    public Model get_critic(int num_states, int num_actions)
+    {
+        // State as input
+        var state_input = keras.layers.Input(shape: num_states);
+
+        // Action as input
+        var action_input = keras.layers.Input(shape: num_actions);
+
+        var concat = keras.layers.Concatenate(axis: 1).Apply(new Tensors(state_input, action_input));
+
+        var outputs = keras.layers.Dense(1).Apply(concat);
+
+        Model model = keras.Model(new Tensors(state_input, action_input), outputs);
+        model.summary();
+
+        return model;
+    }
+
+    [TestMethod]
+    public void GetGradient_Test()
+    {
+        var numStates = 3;
+        var numActions = 1;
+        var batchSize = 64;
+        var gamma = 0.99f;
+
+        var target_actor_model = get_actor(numStates);
+        var target_critic_model = get_critic(numStates, numActions);
+        var critic_model = get_critic(numStates, numActions);
+
+        Tensor state_batch = tf.convert_to_tensor(np.zeros((batchSize, numStates)), TF_DataType.TF_FLOAT);
+        Tensor action_batch = tf.convert_to_tensor(np.zeros((batchSize, numActions)), TF_DataType.TF_FLOAT);
+        Tensor reward_batch = tf.convert_to_tensor(np.zeros((batchSize, 1)), TF_DataType.TF_FLOAT);
+        Tensor next_state_batch = tf.convert_to_tensor(np.zeros((batchSize, numStates)), TF_DataType.TF_FLOAT);
+
+        using (var tape = tf.GradientTape())
+        {
+            var target_actions = target_actor_model.Apply(next_state_batch, training: true);
+            var target_critic_value = target_critic_model.Apply(new Tensors(new Tensor[] { next_state_batch, target_actions }), training: true);
+
+            var y = reward_batch + tf.multiply(gamma, target_critic_value);
+
+            var critic_value = critic_model.Apply(new Tensors(new Tensor[] { state_batch, action_batch }), training: true);
+
+            var critic_loss = math_ops.reduce_mean(math_ops.square(y - critic_value));
+
+            var critic_grad = tape.gradient(critic_loss, critic_model.TrainableVariables);
+
+            Assert.IsNotNull(critic_grad);
+            Assert.IsNotNull(critic_grad.First());
+        }
+    }
+}
\ No newline at end of file
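Context for the first hunk: TensorFlow's Unpack op declares num as a required attribute, so forwarding it through SetAttributes is what lets the eager ExecuteOp path build the op instead of failing to infer the output count. Below is a minimal usage sketch of the patched gen_array_ops.unpack; it is illustrative only and not part of the diff above, and it assumes the TensorFlow.NET eager API (tf.constant) is available as used elsewhere in this repo.

// Illustrative sketch, not included in the patch.
using Tensorflow;
using static Tensorflow.Binding;

var value = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f }, { 5f, 6f } }); // shape (3, 2)

// Split along axis 0 into 3 tensors of shape (2). `num` must equal the length of the
// unpacked axis; with this patch it is forwarded as the op's `num` attribute.
Tensor[] parts = gen_array_ops.unpack(value, num: 3, axis: 0);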