Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 68df1b7

Browse files
committed
tf.data.Dataset skip() #446
1 parent a174a84 commit 68df1b7

File tree

5 files changed

+67
-0
lines changed

5 files changed

+67
-0
lines changed

src/TensorFlowNET.Core/Data/DatasetV2.cs

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ public IDatasetV2 shard(int num_shards, int index)
4141
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
4242
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
4343

44+
public IDatasetV2 skip(int count)
45+
=> new SkipDataset(this, count);
46+
4447
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
4548
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
4649

src/TensorFlowNET.Core/Data/IDatasetV2.cs

+7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
3434

3535
IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);
3636

37+
/// <summary>
38+
/// Creates a `Dataset` that skips `count` elements from this dataset.
39+
/// </summary>
40+
/// <param name="count"></param>
41+
/// <returns></returns>
42+
IDatasetV2 skip(int count);
43+
3744
IDatasetV2 batch(int batch_size, bool drop_remainder = false);
3845

3946
IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null);
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow
7+
{
8+
/// <summary>
9+
/// A `Dataset` skipping the first `count` elements from its input.
10+
/// </summary>
11+
public class SkipDataset : UnaryUnchangedStructureDataset
12+
{
13+
Tensor _count;
14+
15+
public SkipDataset(IDatasetV2 input_dataset,
16+
int count) : base(input_dataset)
17+
{
18+
_count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count");
19+
variant_tensor = ops.skip_dataset(input_dataset.variant_tensor,
20+
_count,
21+
output_types, output_shapes);
22+
}
23+
}
24+
}

src/TensorFlowNET.Core/Operations/dataset_ops.cs

+18
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,24 @@ public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
106106
throw new NotImplementedException("");
107107
}
108108

109+
public Tensor skip_dataset(Tensor input_dataset, Tensor count,
110+
TF_DataType[] output_types, TensorShape[] output_shapes,
111+
string name = null)
112+
{
113+
if (tf.Context.executing_eagerly())
114+
{
115+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
116+
"SkipDataset", name,
117+
null,
118+
input_dataset, count,
119+
"output_types", output_types,
120+
"output_shapes", output_shapes);
121+
return results[0];
122+
}
123+
124+
throw new NotImplementedException("");
125+
}
126+
109127
public Tensor dummy_seed_generator(string name = null)
110128
{
111129
if (tf.Context.executing_eagerly())

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+15
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,20 @@ public void Shard()
8484
value += 3;
8585
}
8686
}
87+
88+
[TestMethod]
89+
public void Skip()
90+
{
91+
long value = 7;
92+
93+
var dataset = tf.data.Dataset.range(10);
94+
dataset = dataset.skip(7);
95+
96+
foreach (var item in dataset)
97+
{
98+
Assert.AreEqual(value, (long)item.Item1);
99+
value ++;
100+
}
101+
}
87102
}
88103
}

0 commit comments

Comments
 (0)