Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 3d943a1

Browse files
committed
IDataset cardinality.
1 parent a994a86 commit 3d943a1

File tree

10 files changed

+155
-33
lines changed

10 files changed

+155
-33
lines changed

src/TensorFlowNET.Core/APIs/tf.data.cs

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ public partial class tensorflow
2323
/// <summary>
/// Entry point for the <c>tf.data</c> API surface: sentinel constants plus the
/// dataset factory exposed through <see cref="Dataset"/>.
/// </summary>
public class DataOps
{
    // The sentinels below are read through an instance (e.g. tf.data.AUTOTUNE), so they
    // stay instance fields; readonly keeps callers from overwriting them by accident.

    /// <summary>
    /// Sentinel value -1. NOTE(review): presumably passed as <c>num_parallel_calls</c>
    /// to request automatic tuning - confirm against the map() implementation.
    /// </summary>
    public readonly int AUTOTUNE = -1;

    /// <summary>
    /// Value compared against <c>cardinality()</c> for a dataset that repeats indefinitely.
    /// </summary>
    public readonly int INFINITE_CARDINALITY = -1;

    /// <summary>
    /// Value compared against <c>cardinality()</c> when the element count cannot be
    /// determined (for example after <c>filter</c>).
    /// </summary>
    public readonly int UNKNOWN_CARDINALITY = -2;

    /// <summary>
    /// Factory object used to create datasets (e.g. <c>tf.data.Dataset.range(10)</c>).
    /// </summary>
    public DatasetManager Dataset { get; } = new DatasetManager();
}
2830
}

src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs

+26-26
Original file line numberDiff line numberDiff line change
@@ -29,48 +29,48 @@ namespace Tensorflow.Contexts
2929
/// </summary>
3030
public sealed partial class Context
3131
{
32-
// [DebuggerStepThrough]
33-
public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args)
32+
Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args)
3433
{
35-
Func<Tensors> graphAction = () =>
34+
var keywords = new Dictionary<string, object>();
35+
if (args.OpInputArgs != null)
3636
{
37-
var keywords = new Dictionary<string, object>();
38-
if(args.OpInputArgs != null)
39-
{
40-
foreach (var (i, input) in enumerate(args.OpInputArgs))
41-
keywords[$"input_{i}"] = input;
42-
}
37+
foreach (var (i, input) in enumerate(args.OpInputArgs))
38+
keywords[$"input_{i}"] = input;
39+
}
4340

44-
if(args.OpAttrs != null)
45-
{
46-
foreach (var attr in args.OpAttrs)
47-
keywords[attr.Key] = attr.Value;
48-
}
41+
if (args.OpAttrs != null)
42+
{
43+
foreach (var attr in args.OpAttrs)
44+
keywords[attr.Key] = attr.Value;
45+
}
4946

50-
return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
51-
};
47+
return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
48+
}
5249

53-
Func<Tensors> eagerAction = () =>
50+
Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args)
51+
{
52+
var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
5453
{
55-
var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
56-
{
57-
attrs = args.OpAttrs
58-
};
59-
return tf.Runner.TFE_FastPathExecute(opExecInfo);
54+
attrs = args.OpAttrs
6055
};
56+
return tf.Runner.TFE_FastPathExecute(opExecInfo);
57+
}
6158

59+
// [DebuggerStepThrough]
60+
public Tensors ExecuteOp(string opType, string name, ExecuteOpArgs args)
61+
{
6262
if (tf.Context.has_graph_arg(args.OpInputArgs))
6363
{
6464
if (executing_eagerly())
6565
{
6666
graph_mode();
67-
var result = graphAction();
67+
var result = ExecGraphAction(opType, name, args);
6868
restore_mode();
6969
return result;
7070
}
7171
else
7272
{
73-
var result = graphAction();
73+
var result = ExecGraphAction(opType, name, args);
7474
if (tf.Runner.MustRecordGradient())
7575
{
7676
var op = result[0].op;
@@ -92,14 +92,14 @@ public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args)
9292
args1[i + 1] = arg.Value;
9393
i += 2;
9494
}
95-
tf.Runner.RecordGradient(OpType, op.inputs, args1, op.outputs);
95+
tf.Runner.RecordGradient(opType, op.inputs, args1, op.outputs);
9696
}
9797
return result;
9898
}
9999
}
100100
else
101101
{
102-
return eagerAction();
102+
return ExecEagerAction(opType, name, args);
103103
}
104104
}
105105
}

src/TensorFlowNET.Core/Data/DatasetV2.cs

+10-2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
7070
num_parallel_calls: num_parallel_calls,
7171
preserve_cardinality: true);
7272

73+
/// <summary>
/// Builds a <see cref="FilterDataset"/> over this dataset using a tensor-valued predicate.
/// </summary>
/// <param name="predicate_func">Predicate forwarded unchanged to the <see cref="FilterDataset"/> constructor.</param>
/// <returns>The newly created filtering dataset.</returns>
public IDatasetV2 filter(Func<Tensors, Tensors> predicate_func)
{
    return new FilterDataset(this, predicate_func);
}
75+
76+
/// <summary>
/// Builds a <see cref="FilterDataset"/> over this dataset using a predicate that returns a plain <c>bool</c>.
/// </summary>
/// <param name="predicate_func">Predicate forwarded unchanged to the <see cref="FilterDataset"/> constructor.</param>
/// <returns>The newly created filtering dataset.</returns>
public IDatasetV2 filter(Func<Tensor, bool> predicate_func)
{
    return new FilterDataset(this, predicate_func);
}
78+
7379
public OwnedIterator make_one_shot_iterator()
7480
{
7581
if (tf.Context.executing_eagerly())
@@ -105,13 +111,15 @@ public IDatasetV2 apply_options()
105111
// (3) Apply graph rewrite options
106112
var graph_rewrites = new[]
107113
{
108-
"noop_elimination",
109114
"map_and_batch_fusion",
115+
"map_parallelization",
116+
"noop_elimination",
110117
"shuffle_and_repeat_fusion"
111118
};
112119
var graph_rewrite_configs = new string[]
113120
{
114121
"autotune_buffer_sizes:autotune:true",
122+
"batch_parallelization:autotune:true",
115123
"disable_prefetch_legacy_autotune:autotune:true",
116124
"enable_gradient_descent:autotune:true",
117125
"map_parallelization:autotune:true"
@@ -124,7 +132,7 @@ public IDatasetV2 apply_options()
124132
return dataset;
125133
}
126134

127-
public Tensor dataset_cardinality(string name = null)
135+
public Tensor cardinality(string name = null)
128136
=> tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));
129137

130138
public override string ToString()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using System;
2+
using Tensorflow.Functions;
3+
using static Tensorflow.Binding;
4+
5+
namespace Tensorflow
6+
{
7+
/// <summary>
8+
/// A `Dataset` that filters its input according to a predicate function.
9+
/// </summary>
10+
public class FilterDataset : UnaryDataset
11+
{
12+
public FilterDataset(IDatasetV2 input_dataset,
13+
Func<Tensor, bool> predicate_func) : base(input_dataset)
14+
{
15+
Func<Tensors, Tensors> predicate_func_update = x =>
16+
{
17+
var result = predicate_func(x);
18+
return constant_op.constant(result);
19+
};
20+
21+
var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
22+
func.Enter();
23+
var inputs = new Tensors();
24+
foreach (var input in input_dataset.element_spec)
25+
inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
26+
var outputs = predicate_func_update(inputs);
27+
func.ToGraph(inputs, outputs);
28+
func.Exit();
29+
30+
structure = func.OutputStructure;
31+
32+
variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
33+
func,
34+
output_types,
35+
output_shapes);
36+
}
37+
38+
public FilterDataset(IDatasetV2 input_dataset,
39+
Func<Tensors, Tensors> predicate_func) : base(input_dataset)
40+
{
41+
var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
42+
func.Enter();
43+
var inputs = new Tensors();
44+
foreach (var input in input_dataset.element_spec)
45+
inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
46+
var outputs = predicate_func(inputs);
47+
func.ToGraph(inputs, outputs);
48+
func.Exit();
49+
50+
structure = func.OutputStructure;
51+
52+
variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
53+
func,
54+
output_types,
55+
output_shapes);
56+
}
57+
}
58+
}

src/TensorFlowNET.Core/Data/IDatasetV2.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ IDatasetV2 map(Func<Tensors, Tensors> map_func,
7272
IDatasetV2 map(Func<Tensors, Tensors> map_func,
7373
int num_parallel_calls);
7474

75+
IDatasetV2 filter(Func<Tensors, Tensors> map_func);
76+
IDatasetV2 filter(Func<Tensor, bool> map_func);
77+
7578
OwnedIterator make_one_shot_iterator();
7679

7780
IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);
@@ -91,6 +94,6 @@ IDatasetV2 map(Func<Tensors, Tensors> map_func,
9194
/// </summary>
9295
/// <param name="name"></param>
9396
/// <returns></returns>
94-
Tensor dataset_cardinality(string name = null);
97+
Tensor cardinality(string name = null);
9598
}
9699
}

src/TensorFlowNET.Core/Operations/dataset_ops.cs

+19
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,25 @@ public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] outp
249249
preserve_cardinality
250250
}));
251251

252+
/// <summary>
253+
/// Creates a dataset containing elements of `input_dataset` matching `predicate`.
254+
/// </summary>
255+
/// <param name="dataset"></param>
256+
/// <param name="predicate"></param>
257+
/// <param name="output_types"></param>
258+
/// <param name="output_shapes"></param>
259+
/// <param name="name"></param>
260+
/// <returns></returns>
261+
public Tensor filter_dataset(Tensor dataset, ConcreteFunction predicate, TF_DataType[] output_types, TensorShape[] output_shapes,
262+
string name = null)
263+
=> tf.Context.ExecuteOp("FilterDataset", name, new ExecuteOpArgs(dataset, new Tensor[0])
264+
.SetAttributes(new
265+
{
266+
predicate,
267+
output_types,
268+
output_shapes
269+
}));
270+
252271
/// <summary>
253272
/// Creates a dataset that applies `f` to the outputs of `input_dataset`.
254273
/// </summary>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Runtime.CompilerServices;
3+
4+
namespace Tensorflow
5+
{
6+
public partial class Tensor
7+
{
8+
public static Tensor operator !=(Tensor x, int y)
9+
=> gen_math_ops.not_equal(x, math_ops.cast(y, dtype: x.dtype));
10+
public static Tensor operator ==(Tensor x, int y)
11+
=> gen_math_ops.equal(x, math_ops.cast(y, dtype: x.dtype));
12+
}
13+
}

src/TensorFlowNET.Core/Tensors/constant_op.cs

+6
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF
144144
break;
145145
}
146146
}
147+
else if (dtype != TF_DataType.DtInvalid &&
148+
value is NDArray nd &&
149+
dtypes.as_dtype(nd.dtype) != dtype)
150+
{
151+
value = nd.astype(dtype.as_numpy_dtype());
152+
}
147153

148154
if (dtype == TF_DataType.TF_STRING && value is byte[] bytes)
149155
return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING);

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
8787
if (adapter_steps > -1)
8888
return adapter_steps;
8989

90-
var size = dataset.dataset_cardinality();
90+
var size = dataset.cardinality();
9191
return size.numpy();
9292
}
9393

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+16-3
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ public void Cache()
147147
public void Cardinality()
148148
{
149149
var dataset = tf.data.Dataset.range(10);
150-
var cardinality = dataset.dataset_cardinality();
150+
var cardinality = dataset.cardinality();
151151
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
152152
dataset = dataset.map(x => x[0] + 1);
153-
cardinality = dataset.dataset_cardinality();
153+
cardinality = dataset.cardinality();
154154
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
155155
}
156156

@@ -159,10 +159,23 @@ public void CardinalityWithAutoTune()
159159
{
160160
var dataset = tf.data.Dataset.range(10);
161161
dataset = dataset.map(x => x, num_parallel_calls: -1);
162-
var cardinality = dataset.dataset_cardinality();
162+
var cardinality = dataset.cardinality();
163163
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
164164
}
165165

166+
[TestMethod]
167+
public void CardinalityWithRepeat()
168+
{
169+
var dataset = tf.data.Dataset.range(10);
170+
dataset = dataset.repeat();
171+
var cardinality = dataset.cardinality();
172+
Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy());
173+
174+
dataset = dataset.filter(x => true);
175+
cardinality = dataset.cardinality();
176+
Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy());
177+
}
178+
166179
[TestMethod]
167180
public void Shuffle()
168181
{

0 commit comments

Comments
 (0)