|
| 1 | +/* |
| 2 | + * Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + * ======================================================================= |
| 16 | + */ |
| 17 | + |
| 18 | +package org.tensorflow.types; |
| 19 | + |
| 20 | +import java.util.function.Consumer; |
| 21 | +import org.tensorflow.SparseTensor; |
| 22 | +import org.tensorflow.Tensor; |
| 23 | +import org.tensorflow.exceptions.TensorFlowException; |
| 24 | +import org.tensorflow.internal.types.TUint16Mapper; |
| 25 | +import org.tensorflow.ndarray.NdArray; |
| 26 | +import org.tensorflow.ndarray.Shape; |
| 27 | +import org.tensorflow.ndarray.ShortNdArray; |
| 28 | +import org.tensorflow.ndarray.StdArrays; |
| 29 | +import org.tensorflow.ndarray.buffer.ShortDataBuffer; |
| 30 | +import org.tensorflow.proto.framework.DataType; |
| 31 | +import org.tensorflow.types.annotation.TensorType; |
| 32 | +import org.tensorflow.types.family.TIntegral; |
| 33 | + |
| 34 | +/** 16-bit unsigned integer tensor type. */ |
| 35 | +@TensorType(dataType = DataType.DT_UINT16, byteSize = 2, mapperClass = TUint16Mapper.class) |
| 36 | +public interface TUint16 extends ShortNdArray, TIntegral { |
| 37 | + |
| 38 | + /** |
| 39 | + * Allocates a new tensor for storing a single short value. |
| 40 | + * |
| 41 | + * @param value short to store in the new tensor |
| 42 | + * @return the new tensor |
| 43 | + */ |
| 44 | + static TUint16 scalarOf(short value) { |
| 45 | + return Tensor.of(TUint16.class, Shape.scalar(), data -> data.setShort(value)); |
| 46 | + } |
| 47 | + |
| 48 | + /** |
| 49 | + * Allocates a new tensor for storing a vector of shorts. |
| 50 | + * |
| 51 | + * @param values short to store in the new tensor |
| 52 | + * @return the new tensor |
| 53 | + */ |
| 54 | + static TUint16 vectorOf(short... values) { |
| 55 | + if (values == null) { |
| 56 | + throw new IllegalArgumentException(); |
| 57 | + } |
| 58 | + return Tensor.of( |
| 59 | + TUint16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); |
| 60 | + } |
| 61 | + |
| 62 | + /** |
| 63 | + * Allocates a new tensor which is a copy of a given array of shorts. |
| 64 | + * |
| 65 | + * <p>The tensor will have the same shape as the source array and its data will be copied. |
| 66 | + * |
| 67 | + * @param src the source array giving the shape and data to the new tensor |
| 68 | + * @return the new tensor |
| 69 | + */ |
| 70 | + static TUint16 tensorOf(NdArray<Short> src) { |
| 71 | + return Tensor.of(TUint16.class, src.shape(), src::copyTo); |
| 72 | + } |
| 73 | + |
| 74 | + /** |
| 75 | + * Allocates a new tensor of the given shape. |
| 76 | + * |
| 77 | + * @param shape shape of the tensor to allocate |
| 78 | + * @return the new tensor |
| 79 | + */ |
| 80 | + static TUint16 tensorOf(Shape shape) { |
| 81 | + return Tensor.of(TUint16.class, shape); |
| 82 | + } |
| 83 | + |
| 84 | + /** |
| 85 | + * Allocates a new tensor of the given shape, initialized with the provided data. |
| 86 | + * |
| 87 | + * @param shape shape of the tensor to allocate |
| 88 | + * @param data buffer of shorts to initialize the tensor with |
| 89 | + * @return the new tensor |
| 90 | + */ |
| 91 | + static TUint16 tensorOf(Shape shape, ShortDataBuffer data) { |
| 92 | + return Tensor.of(TUint16.class, shape, d -> d.write(data)); |
| 93 | + } |
| 94 | + |
| 95 | + /** |
| 96 | + * Allocates a new tensor of the given shape and initialize its data. |
| 97 | + * |
| 98 | + * @param shape shape of the tensor to allocate |
| 99 | + * @param dataInit tensor data initializer |
| 100 | + * @return the new tensor |
| 101 | + * @throws TensorFlowException if the tensor cannot be allocated or initialized |
| 102 | + */ |
| 103 | + static TUint16 tensorOf(Shape shape, Consumer<TUint16> dataInit) { |
| 104 | + return Tensor.of(TUint16.class, shape, dataInit); |
| 105 | + } |
| 106 | + |
| 107 | + /** |
| 108 | + * Create a sparse tensors from {@code indices}, {@code values} and {@code denseShape} dense |
| 109 | + * tensors, with a default value of zero. |
| 110 | + * |
| 111 | + * <p>The returned instance also implements the {@link SparseTensor SparseTensor<TUint16>} |
| 112 | + * interface, allowing a user to access directly the dense tensors when needed. |
| 113 | + * |
| 114 | + * @param indices A 2-D tensor of shape {@code [N, ndims]}, that specifies the indices of the |
| 115 | + * elements in the sparse tensor that contain non-default values (elements are zero-indexed). |
| 116 | + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of |
| 117 | + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. |
| 118 | + * @param values A 1-D tensor of shape {@code [N]}, which supplies the values for each element in |
| 119 | + * indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter {@code |
| 120 | + * values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse tensor has a value of |
| 121 | + * {@code 18}, and element {@code [2,4,0]} of the tensor has a value of {@code 3}. |
| 122 | + * @param denseShape A 1-D tensor of shape {@code [ndims]} where each the value at index {@code i} |
| 123 | + * represents the size of dimension {@code i} in a dense version of that tensor. |
| 124 | + * @return the new sparse tensor |
| 125 | + * @see SparseTensor for more details on sparse tensors and how to release their memory properly |
| 126 | + */ |
| 127 | + static TUint16 sparseTensorOf(TInt64 indices, TUint16 values, TInt64 denseShape) { |
| 128 | + return SparseTensor.of(indices, values, denseShape).asTypedTensor(); |
| 129 | + } |
| 130 | +} |
0 commit comments