Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix issue #760 #992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/gen_array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataTyp

public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null)
=> tf.Context.ExecuteOp("Unpack", name, new ExecuteOpArgs(value, num)
.SetAttributes(new { axis }));
.SetAttributes(new { axis, num }));

public static Tensor where(Tensor condition, string name = null)
{
Expand Down
76 changes: 76 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/Gradient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;

namespace TensorFlowNET.Keras.UnitTest;

[TestClass]
public class GradientTest
{
public Model get_actor(int num_states)
{
var inputs = keras.layers.Input(shape: num_states);
var outputs = keras.layers.Dense(1, activation: keras.activations.Tanh).Apply(inputs);

Model model = keras.Model(inputs, outputs);

return model;
}

public Model get_critic(int num_states, int num_actions)
{
// State as input
var state_input = keras.layers.Input(shape: num_states);

// Action as input
var action_input = keras.layers.Input(shape: num_actions);

var concat = keras.layers.Concatenate(axis: 1).Apply(new Tensors(state_input, action_input));

var outputs = keras.layers.Dense(1).Apply(concat);

Model model = keras.Model(new Tensors(state_input, action_input), outputs);
model.summary();

return model;
}

[TestMethod]
public void GetGradient_Test()
{
var numStates = 3;
var numActions = 1;
var batchSize = 64;
var gamma = 0.99f;

var target_actor_model = get_actor(numStates);
var target_critic_model = get_critic(numStates, numActions);
var critic_model = get_critic(numStates, numActions);

Tensor state_batch = tf.convert_to_tensor(np.zeros((batchSize, numStates)), TF_DataType.TF_FLOAT);
Tensor action_batch = tf.convert_to_tensor(np.zeros((batchSize, numActions)), TF_DataType.TF_FLOAT);
Tensor reward_batch = tf.convert_to_tensor(np.zeros((batchSize, 1)), TF_DataType.TF_FLOAT);
Tensor next_state_batch = tf.convert_to_tensor(np.zeros((batchSize, numStates)), TF_DataType.TF_FLOAT);

using (var tape = tf.GradientTape())
{
var target_actions = target_actor_model.Apply(next_state_batch, training: true);
var target_critic_value = target_critic_model.Apply(new Tensors(new Tensor[] { next_state_batch, target_actions }), training: true);

var y = reward_batch + tf.multiply(gamma, target_critic_value);

var critic_value = critic_model.Apply(new Tensors(new Tensor[] { state_batch, action_batch }), training: true);

var critic_loss = math_ops.reduce_mean(math_ops.square(y - critic_value));

var critic_grad = tape.gradient(critic_loss, critic_model.TrainableVariables);

Assert.IsNotNull(critic_grad);
Assert.IsNotNull(critic_grad.First());
}
}
}