Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b54cbaa

Browse files
committed
Fix binary_accuracy for keras.
1 parent e007e86 commit b54cbaa

File tree

6 files changed

+28
-11
lines changed

6 files changed

+28
-11
lines changed

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>2.2.0</TargetTensorFlow>
8-
<Version>0.40.0</Version>
8+
<Version>0.40.1</Version>
99
<LangVersion>8.0</LangVersion>
1010
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
1111
<Company>SciSharp STACK</Company>
@@ -19,7 +19,7 @@
1919
<Description>Google's TensorFlow full binding in .NET Standard.
2020
Building, training and infering deep learning models.
2121
https://tensorflownet.readthedocs.io</Description>
22-
<AssemblyVersion>0.40.0.0</AssemblyVersion>
22+
<AssemblyVersion>0.40.1.0</AssemblyVersion>
2323
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.
2424

2525
* Eager Mode is added finally.
@@ -32,7 +32,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works.
3232
Keras API is a separate package released as TensorFlow.Keras.
3333

3434
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes>
35-
<FileVersion>0.40.0.0</FileVersion>
35+
<FileVersion>0.40.1.0</FileVersion>
3636
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3737
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
3838
<SignAssembly>true</SignAssembly>

src/TensorFlowNET.Core/Tensors/tensor_util.cs

-2
Original file line numberDiff line numberDiff line change
@@ -596,8 +596,6 @@ public static string to_numpy_string(Tensor tensor)
596596
case TF_DataType.TF_STRING:
597597
return string.Join(string.Empty, nd.ToArray<byte>()
598598
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
599-
case TF_DataType.TF_BOOL:
600-
return nd.GetBoolean(0).ToString();
601599
case TF_DataType.TF_VARIANT:
602600
case TF_DataType.TF_RESOURCE:
603601
return "<unprintable>";

src/TensorFlowNET.Keras/BackendImpl.cs

+8
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,14 @@ public void manual_variable_initialization(bool value)
137137
{
138138
_MANUAL_VAR_INIT = value;
139139
}
140+
141+
public Tensor mean(Tensor x, int axis = -1, bool keepdims = false)
142+
{
143+
if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL)
144+
x = math_ops.cast(x, TF_DataType.TF_FLOAT);
145+
return math_ops.reduce_mean(x, axis: new[] { axis }, keepdims: false);
146+
}
147+
140148
public GraphLearningPhase learning_phase()
141149
{
142150
var graph = tf.get_default_graph();

src/TensorFlowNET.Keras/Engine/MetricsContainer.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
6868
bool is_binary = y_p_last_dim == 1;
6969
bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;
7070

71-
if (is_sparse_categorical)
71+
if (is_binary)
72+
metric_obj = keras.metrics.binary_accuracy;
73+
else if (is_sparse_categorical)
7274
metric_obj = keras.metrics.sparse_categorical_accuracy;
7375
else
7476
metric_obj = keras.metrics.categorical_accuracy;

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
namespace Tensorflow.Keras.Metrics
1+
using static Tensorflow.KerasApi;
2+
3+
namespace Tensorflow.Keras.Metrics
24
{
35
public class MetricsApi
46
{
7+
public Tensor binary_accuracy(Tensor y_true, Tensor y_pred)
8+
{
9+
float threshold = 0.5f;
10+
y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype);
11+
return keras.backend.mean(math_ops.equal(y_true, y_pred), axis: -1);
12+
}
13+
514
public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred)
615
{
716
var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1));

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
<LangVersion>8.0</LangVersion>
77
<RootNamespace>Tensorflow.Keras</RootNamespace>
88
<Platforms>AnyCPU;x64</Platforms>
9-
<Version>0.5.0</Version>
9+
<Version>0.5.1</Version>
1010
<Authors>Haiping Chen</Authors>
1111
<Product>Keras for .NET</Product>
12-
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
12+
<Copyright>Apache 2.0, Haiping Chen 2021</Copyright>
1313
<PackageId>TensorFlow.Keras</PackageId>
1414
<PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl>
1515
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
@@ -35,8 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
3535
<RepositoryType>Git</RepositoryType>
3636
<SignAssembly>true</SignAssembly>
3737
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
38-
<AssemblyVersion>0.5.0.0</AssemblyVersion>
39-
<FileVersion>0.5.0.0</FileVersion>
38+
<AssemblyVersion>0.5.1.0</AssemblyVersion>
39+
<FileVersion>0.5.1.0</FileVersion>
4040
<PackageLicenseFile>LICENSE</PackageLicenseFile>
4141
</PropertyGroup>
4242

0 commit comments

Comments
 (0)