Commit 8d0bd50

Lee Reid authored and Oceania2018 committed
Added tests with ignore attribute
1 parent 30da91c commit 8d0bd50
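
For context (an editorial note, not part of the commit): MSTest's [Ignore] attribute makes a [TestMethod] compile but be reported as skipped by the test runner instead of executed. A minimal sketch, with a hypothetical class and test name for illustration:

    using Microsoft.VisualStudio.TestTools.UnitTesting;

    [TestClass]
    public class IgnoreSketch          // hypothetical class for illustration
    {
        [Ignore]                       // runner skips this test entirely
        [TestMethod]
        public void SkippedTest() { }  // reported as skipped, not passed or failed
    }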

File tree

1 file changed: +76 -0 lines changed

test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs

@@ -1,5 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
+using System.Collections.Generic;
 using System.Linq;
 using Tensorflow;
 using Tensorflow.UnitTest;
@@ -24,6 +25,81 @@ public void ConstantSquare()
             Assert.AreEqual((float)grad, 3.0f);
         }

+        [Ignore]
+        [TestMethod]
+        public void SquaredDifference_Constant()
+        {
+            // Calculate the gradient of (x1 - x2)^2
+            // by Automatic Differentiation in Eager mode
+            var x1 = tf.constant(7f);
+            var x2 = tf.constant(11f);
+
+            // Sanity check
+            using (var tape = tf.GradientTape())
+            {
+                tape.watch(x2);
+                var loss = tf.multiply((x1 - x2), (x1 - x2));
+
+                var result = tape.gradient(loss, x2);
+                // Expected is 2*(11-7) = 8
+                Assert.AreEqual((float)result, 8f);
+            }
+
+            // Actual test
+            using (var tape = tf.GradientTape())
+            {
+                tape.watch(x2);
+                var loss = tf.squared_difference(x1, x2);
+
+                // Expected is 2*(11-7) = 8
+                var result = tape.gradient(loss, x2);
+                Assert.AreEqual((float)result, 8f);
+            }
+        }
+
+        [Ignore]
+        [TestMethod]
+        public void SquaredDifference_1D()
+        {
+            // Calculate the gradient of (x1 - x2)^2
+            // by Automatic Differentiation in Eager mode
+            // Expected is 2*(x2 - x1) elementwise (sign matters, so not 2*abs(x1 - x2))
+            Tensor x1 = new NumSharp.NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
+            Tensor x2 = new NumSharp.NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
+            float[] expected = new float[] {
+                (29 - 1) * 2,
+                (27 - 3) * 2,
+                (23 - 5) * 2,
+                (7 - 21) * 2,
+                (11 - 19) * 2,
+                (13 - 17) * 2
+            };
+
+            // Sanity check
+            using (var tape = tf.GradientTape())
+            {
+                tape.watch(x1);
+                tape.watch(x2);
+                var loss = tf.multiply((x1 - x2), (x1 - x2));
+
+                var result = tape.gradient(loss, x2);
+                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
+            }
+
+            // Actual test
+            using (var tape = tf.GradientTape())
+            {
+                tape.watch(x1);
+                tape.watch(x2);
+                var loss = tf.squared_difference(x1, x2);
+
+                var result = tape.gradient(loss, x2);
+                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
+            }
+        }
+
         /// <summary>
         /// Calculate the gradient of w * w * w
         /// Higher-order gradients
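
For reference (an addition, not part of the commit diff), the expected values follow from differentiating the loss with respect to x2:

    \frac{\partial}{\partial x_2}(x_1 - x_2)^2 = -2\,(x_1 - x_2) = 2\,(x_2 - x_1)

So the scalar test expects 2*(11 - 7) = 8, and the 1D test expects 2*(x2 - x1) elementwise, which can be negative (e.g. 2*(7 - 21) = -28 for the fourth element); this is why the expected gradient is 2*(x2 - x1) rather than 2*abs(x1 - x2).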
