Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9cfea86

Browse files
authored
Add TUint16 type (tensorflow#469)
1 parent 936e379 commit 9cfea86

File tree

5 files changed

+348
-32
lines changed

5 files changed

+348
-32
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =======================================================================
16+
*/
17+
package org.tensorflow.internal.types;
18+
19+
import org.bytedeco.javacpp.PointerScope;
20+
import org.tensorflow.RawTensor;
21+
import org.tensorflow.SparseTensor;
22+
import org.tensorflow.TensorMapper;
23+
import org.tensorflow.internal.buffer.TensorBuffers;
24+
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
25+
import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray;
26+
import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray;
27+
import org.tensorflow.proto.framework.DataType;
28+
import org.tensorflow.types.TInt64;
29+
import org.tensorflow.types.TUint16;
30+
31+
/**
32+
* Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_Uint16} tensors to a
33+
* n-dimensional data space.
34+
*/
35+
public final class TUint16Mapper extends TensorMapper<TUint16> {
36+
37+
@Override
38+
protected TUint16 mapDense(RawTensor tensor) {
39+
ShortDataBuffer buffer = TensorBuffers.toShorts(nativeHandle(tensor));
40+
return new DenseTUint16(tensor, buffer);
41+
}
42+
43+
@Override
44+
protected SparseTensor<TUint16> mapSparse(
45+
TInt64 indices, TUint16 values, TInt64 denseShape, PointerScope tensorScope) {
46+
return new TUint16Mapper.SparseTUint16(indices, values, denseShape, tensorScope);
47+
}
48+
49+
private static final class DenseTUint16 extends ShortDenseNdArray implements TUint16 {
50+
51+
@Override
52+
public Class<TUint16> type() {
53+
return TUint16.class;
54+
}
55+
56+
@Override
57+
public DataType dataType() {
58+
return asRawTensor().dataType();
59+
}
60+
61+
@Override
62+
public long numBytes() {
63+
return asRawTensor().numBytes();
64+
}
65+
66+
@Override
67+
public void close() {
68+
asRawTensor().close();
69+
}
70+
71+
@Override
72+
public RawTensor asRawTensor() {
73+
return rawTensor;
74+
}
75+
76+
final RawTensor rawTensor;
77+
78+
DenseTUint16(RawTensor rawTensor, ShortDataBuffer buffer) {
79+
super(buffer, rawTensor.shape());
80+
this.rawTensor = rawTensor;
81+
}
82+
}
83+
84+
private static final class SparseTUint16 extends ShortSparseNdArray
85+
implements TUint16, SparseTensor<TUint16> {
86+
87+
@Override
88+
public Class<TUint16> type() {
89+
return TUint16.class;
90+
}
91+
92+
@Override
93+
public DataType dataType() {
94+
return values().dataType();
95+
}
96+
97+
@Override
98+
public long numBytes() {
99+
return SparseHelpers.numBytes(this);
100+
}
101+
102+
@Override
103+
public void close() {
104+
tensorScope.close();
105+
}
106+
107+
@Override
108+
public boolean isSparse() {
109+
return true;
110+
}
111+
112+
@Override
113+
public TInt64 indices() {
114+
return (TInt64) getIndices();
115+
}
116+
117+
@Override
118+
public TUint16 values() {
119+
return (TUint16) getValues();
120+
}
121+
122+
@Override
123+
public TInt64 denseShape() {
124+
return denseShape;
125+
}
126+
127+
SparseTUint16(TInt64 indices, TUint16 values, TInt64 denseShape, PointerScope tensorScope) {
128+
super(indices, values, (short) 0, SparseHelpers.toDimensionalSpace(denseShape));
129+
this.denseShape = denseShape;
130+
this.tensorScope = tensorScope.extend();
131+
}
132+
133+
private final TInt64 denseShape;
134+
private final PointerScope tensorScope;
135+
}
136+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@
2828
import org.tensorflow.types.TInt32;
2929
import org.tensorflow.types.TInt64;
3030
import org.tensorflow.types.TString;
31+
import org.tensorflow.types.TUint16;
3132
import org.tensorflow.types.TUint8;
3233
import org.tensorflow.types.annotation.TensorType;
3334
import org.tensorflow.types.family.TType;
3435

35-
/**
36-
* Repository of all registered tensor types.
37-
*/
36+
/** Repository of all registered tensor types. */
3837
public final class TensorTypeRegistry {
3938

4039
/**
@@ -47,9 +46,10 @@ public final class TensorTypeRegistry {
4746
public static <T extends TType> TensorTypeInfo<T> find(DataType dataType) {
4847
TensorTypeInfo<?> typeInfo = TYPES_BY_CODE.get(dataType.getNumber());
4948
if (typeInfo == null) {
50-
throw new IllegalArgumentException("No tensor type has been registered for data type " + dataType);
49+
throw new IllegalArgumentException(
50+
"No tensor type has been registered for data type " + dataType);
5151
}
52-
return (TensorTypeInfo<T>)typeInfo;
52+
return (TensorTypeInfo<T>) typeInfo;
5353
}
5454

5555
/**
@@ -62,28 +62,37 @@ public static <T extends TType> TensorTypeInfo<T> find(DataType dataType) {
6262
public static <T extends TType> TensorTypeInfo<T> find(Class<T> type) {
6363
TensorTypeInfo<?> typeInfo = TYPES_BY_CLASS.get(type);
6464
if (typeInfo == null) {
65-
throw new IllegalArgumentException("Class \"" + type.getName() + "\" is not registered as a tensor type");
65+
throw new IllegalArgumentException(
66+
"Class \"" + type.getName() + "\" is not registered as a tensor type");
6667
}
67-
return (TensorTypeInfo<T>)typeInfo;
68+
return (TensorTypeInfo<T>) typeInfo;
6869
}
6970

7071
private static final Map<Integer, TensorTypeInfo<?>> TYPES_BY_CODE = new HashMap<>();
71-
private static final Map<Class<? extends TType>, TensorTypeInfo<?>> TYPES_BY_CLASS = new HashMap<>();
72+
private static final Map<Class<? extends TType>, TensorTypeInfo<?>> TYPES_BY_CLASS =
73+
new HashMap<>();
7274

7375
private static <T extends TType> void register(Class<T> type) {
7476
TensorType typeAnnot = type.getDeclaredAnnotation(TensorType.class);
7577
if (typeAnnot == null) {
76-
throw new IllegalArgumentException("Class \"" + type.getName() + "\" must be annotated "
77-
+ "with @TensorType to be registered as a tensor type");
78+
throw new IllegalArgumentException(
79+
"Class \""
80+
+ type.getName()
81+
+ "\" must be annotated "
82+
+ "with @TensorType to be registered as a tensor type");
7883
}
7984
TensorMapper<T> mapper;
8085
try {
81-
mapper = (TensorMapper<T>)typeAnnot.mapperClass().newInstance();
86+
mapper = (TensorMapper<T>) typeAnnot.mapperClass().newInstance();
8287
} catch (ReflectiveOperationException e) {
83-
throw new IllegalArgumentException("Class \"" + type.getName() + "\" must have a public "
84-
+ "parameter-less constructor to be used as a tensor mapper");
88+
throw new IllegalArgumentException(
89+
"Class \""
90+
+ type.getName()
91+
+ "\" must have a public "
92+
+ "parameter-less constructor to be used as a tensor mapper");
8593
}
86-
TensorTypeInfo<T> typeInfo = new TensorTypeInfo<>(type, typeAnnot.dataType(), typeAnnot.byteSize(), mapper);
94+
TensorTypeInfo<T> typeInfo =
95+
new TensorTypeInfo<>(type, typeAnnot.dataType(), typeAnnot.byteSize(), mapper);
8796
TYPES_BY_CLASS.put(type, typeInfo);
8897
TYPES_BY_CODE.put(typeInfo.dataType().getNumber(), typeInfo);
8998
TYPES_BY_CODE.put(typeInfo.dataType().getNumber() + 100, typeInfo);
@@ -100,6 +109,7 @@ private static <T extends TType> void register(Class<T> type) {
100109
register(TInt64.class);
101110
register(TString.class);
102111
register(TUint8.class);
112+
register(TUint16.class);
103113
register(TBfloat16.class);
104114
}
105115
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =======================================================================
16+
*/
17+
18+
package org.tensorflow.types;
19+
20+
import java.util.function.Consumer;
21+
import org.tensorflow.SparseTensor;
22+
import org.tensorflow.Tensor;
23+
import org.tensorflow.exceptions.TensorFlowException;
24+
import org.tensorflow.internal.types.TUint16Mapper;
25+
import org.tensorflow.ndarray.NdArray;
26+
import org.tensorflow.ndarray.Shape;
27+
import org.tensorflow.ndarray.ShortNdArray;
28+
import org.tensorflow.ndarray.StdArrays;
29+
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
30+
import org.tensorflow.proto.framework.DataType;
31+
import org.tensorflow.types.annotation.TensorType;
32+
import org.tensorflow.types.family.TIntegral;
33+
34+
/** 16-bit unsigned integer tensor type. */
35+
@TensorType(dataType = DataType.DT_UINT16, byteSize = 2, mapperClass = TUint16Mapper.class)
36+
public interface TUint16 extends ShortNdArray, TIntegral {
37+
38+
/**
39+
* Allocates a new tensor for storing a single short value.
40+
*
41+
* @param value short to store in the new tensor
42+
* @return the new tensor
43+
*/
44+
static TUint16 scalarOf(short value) {
45+
return Tensor.of(TUint16.class, Shape.scalar(), data -> data.setShort(value));
46+
}
47+
48+
/**
49+
* Allocates a new tensor for storing a vector of shorts.
50+
*
51+
* @param values short to store in the new tensor
52+
* @return the new tensor
53+
*/
54+
static TUint16 vectorOf(short... values) {
55+
if (values == null) {
56+
throw new IllegalArgumentException();
57+
}
58+
return Tensor.of(
59+
TUint16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data));
60+
}
61+
62+
/**
63+
* Allocates a new tensor which is a copy of a given array of shorts.
64+
*
65+
* <p>The tensor will have the same shape as the source array and its data will be copied.
66+
*
67+
* @param src the source array giving the shape and data to the new tensor
68+
* @return the new tensor
69+
*/
70+
static TUint16 tensorOf(NdArray<Short> src) {
71+
return Tensor.of(TUint16.class, src.shape(), src::copyTo);
72+
}
73+
74+
/**
75+
* Allocates a new tensor of the given shape.
76+
*
77+
* @param shape shape of the tensor to allocate
78+
* @return the new tensor
79+
*/
80+
static TUint16 tensorOf(Shape shape) {
81+
return Tensor.of(TUint16.class, shape);
82+
}
83+
84+
/**
85+
* Allocates a new tensor of the given shape, initialized with the provided data.
86+
*
87+
* @param shape shape of the tensor to allocate
88+
* @param data buffer of shorts to initialize the tensor with
89+
* @return the new tensor
90+
*/
91+
static TUint16 tensorOf(Shape shape, ShortDataBuffer data) {
92+
return Tensor.of(TUint16.class, shape, d -> d.write(data));
93+
}
94+
95+
/**
96+
* Allocates a new tensor of the given shape and initialize its data.
97+
*
98+
* @param shape shape of the tensor to allocate
99+
* @param dataInit tensor data initializer
100+
* @return the new tensor
101+
* @throws TensorFlowException if the tensor cannot be allocated or initialized
102+
*/
103+
static TUint16 tensorOf(Shape shape, Consumer<TUint16> dataInit) {
104+
return Tensor.of(TUint16.class, shape, dataInit);
105+
}
106+
107+
/**
108+
* Create a sparse tensors from {@code indices}, {@code values} and {@code denseShape} dense
109+
* tensors, with a default value of zero.
110+
*
111+
* <p>The returned instance also implements the {@link SparseTensor SparseTensor<TUint16>}
112+
* interface, allowing a user to access directly the dense tensors when needed.
113+
*
114+
* @param indices A 2-D tensor of shape {@code [N, ndims]}, that specifies the indices of the
115+
* elements in the sparse tensor that contain non-default values (elements are zero-indexed).
116+
* For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of
117+
* {@code [1,3,1]} and {@code [2,4,0]} have non-default values.
118+
* @param values A 1-D tensor of shape {@code [N]}, which supplies the values for each element in
119+
* indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter {@code
120+
* values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse tensor has a value of
121+
* {@code 18}, and element {@code [2,4,0]} of the tensor has a value of {@code 3}.
122+
* @param denseShape A 1-D tensor of shape {@code [ndims]} where each the value at index {@code i}
123+
* represents the size of dimension {@code i} in a dense version of that tensor.
124+
* @return the new sparse tensor
125+
* @see SparseTensor for more details on sparse tensors and how to release their memory properly
126+
*/
127+
static TUint16 sparseTensorOf(TInt64 indices, TUint16 values, TInt64 denseShape) {
128+
return SparseTensor.of(indices, values, denseShape).asTypedTensor();
129+
}
130+
}

0 commit comments

Comments
 (0)