forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstorage_view.cc
More file actions
242 lines (197 loc) · 8.39 KB
/
Copy pathstorage_view.cc
File metadata and controls
242 lines (197 loc) · 8.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
#include "module.h"
#include <sstream>
#include <ctranslate2/storage_view.h>
#include "utils.h"
using namespace pybind11::literals;
namespace ctranslate2 {
namespace python {
static DataType typestr_to_dtype(const std::string& typestr) {
const auto type_code = typestr[1];
const auto num_bytes = typestr[2];
if (type_code == 'i') {
if (num_bytes == '1')
return DataType::INT8;
if (num_bytes == '2')
return DataType::INT16;
if (num_bytes == '4')
return DataType::INT32;
} else if (type_code == 'f') {
if (num_bytes == '2')
return DataType::FLOAT16;
if (num_bytes == '4')
return DataType::FLOAT32;
}
throw std::invalid_argument("Unsupported type: " + typestr);
}
static std::string dtype_to_typestr(const DataType dtype) {
// Assume little-endian.
switch (dtype) {
case DataType::FLOAT32:
return "<f4";
case DataType::FLOAT16:
return "<f2";
case DataType::INT8:
return "|i1";
case DataType::INT16:
return "<i2";
case DataType::INT32:
return "<i4";
default:
throw std::runtime_error("Data type " + dtype_name(dtype) + " is not supported in the array "
"interface. You should first convert this storage to a supported "
"type, for example with "
"`storage = storage.to(ctranslate2.DataType.float32)`");
}
}
static StorageView create_view_from_array(py::object array) {
auto device = Device::CPU;
py::object interface_obj = py::getattr(array, "__array_interface__", py::none());
if (interface_obj.is_none()) {
interface_obj = py::getattr(array, "__cuda_array_interface__", py::none());
if (interface_obj.is_none())
throw std::invalid_argument("Object does not implement the array interface");
device = Device::CUDA;
}
py::dict interface = interface_obj.cast<py::dict>();
if (interface_obj.contains("strides") && !interface_obj["strides"].is_none())
throw std::invalid_argument("StorageView does not support arrays with non contiguous memory");
auto shape = interface["shape"].cast<Shape>();
auto dtype = typestr_to_dtype(interface["typestr"].cast<std::string>());
auto data = interface["data"].cast<py::tuple>();
auto ptr = data[0].cast<uintptr_t>();
auto read_only = data[1].cast<bool>();
if (read_only)
throw std::invalid_argument("StorageView does not support read-only arrays");
StorageView view(dtype, device);
view.view((void*)ptr, std::move(shape));
return view;
}
static py::dict get_array_interface(const StorageView& view) {
py::tuple shape(view.rank());
for (size_t i = 0; i < shape.size(); ++i)
shape[i] = view.dim(i);
return py::dict(
"shape"_a=shape,
"typestr"_a=dtype_to_typestr(view.dtype()),
"data"_a=py::make_tuple((uintptr_t)view.buffer(), false),
"version"_a=3);
}
void register_storage_view(py::module& m) {
py::enum_<DataType>(m, "DataType")
.value("float32", DataType::FLOAT32)
.value("float16", DataType::FLOAT16)
.value("bfloat16", DataType::BFLOAT16)
.value("int8", DataType::INT8)
.value("int16", DataType::INT16)
.value("int32", DataType::INT32)
;
py::enum_<Device>(m, "Device")
.value("cpu", Device::CPU)
.value("cuda", Device::CUDA)
;
py::class_<StorageView>(
m, "StorageView",
R"pbdoc(
An allocated buffer with shape information.
The object implements the
`Array Interface <https://numpy.org/doc/stable/reference/arrays.interface.html>`_
and the
`CUDA Array Interface <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_
so that it can be passed to Numpy or PyTorch without copy.
Example:
>>> x = np.ones((2, 4), dtype=np.int32)
>>> y = ctranslate2.StorageView.from_array(x)
>>> print(y)
1 1 1 ... 1 1 1
[cpu:0 int32 storage viewed as 2x4]
>>> z = np.array(y)
...
>>> x = torch.ones((2, 4), dtype=torch.int32, device="cuda")
>>> y = ctranslate2.StorageView.from_array(x)
>>> print(y)
1 1 1 ... 1 1 1
[cuda:0 int32 storage viewed as 2x4]
>>> z = torch.as_tensor(y, device="cuda")
)pbdoc")
.def_static("from_array", &create_view_from_array, py::arg("array"),
py::keep_alive<0, 1>(),
R"pbdoc(
Creates a ``StorageView`` from an object implementing the array interface.
Arguments:
array: An object implementing the array interface (e.g. a Numpy array
or a PyTorch Tensor).
Returns:
A new ``StorageView`` instance sharing the same data as the input array.
Raises:
ValueError: if the object does not implement the array interface or
uses an unsupported array specification.
)pbdoc")
.def_property_readonly("dtype", &StorageView::dtype,
"Data type used by the storage.")
.def_property_readonly("shape", &StorageView::shape,
"Shape of the storage view.")
.def_property_readonly("device_index", &StorageView::device_index,
"Device index.")
.def_property_readonly("device",
[](const StorageView& view) {
return device_to_str(view.device());
},
"Device where the storage is allocated (\"cpu\" or \"cuda\").")
.def_property_readonly("__array_interface__", [](const StorageView& view) {
if (view.device() == Device::CUDA)
throw py::attribute_error("Cannot get __array_interface__ when the StorageView "
"is viewing a CUDA array");
return get_array_interface(view);
})
.def_property_readonly("__cuda_array_interface__", [](const StorageView& view) {
if (view.device() == Device::CPU)
throw py::attribute_error("Cannot get __cuda_array_interface__ when the StorageView "
"is viewing a CPU array");
return get_array_interface(view);
})
.def("__str__", [](const StorageView& view) {
std::ostringstream stream;
stream << view;
return stream.str();
})
.def("__repr__", [](const StorageView& view) {
std::ostringstream stream;
stream << view;
return stream.str();
})
.def("to",
[](const StorageView& view, DataType dtype) {
ScopedDeviceSetter device_setter(view.device(), view.device_index());
StorageView converted = view.to(dtype);
synchronize_stream(view.device());
return converted;
},
py::arg("dtype"),
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Converts the storage to another type.
Arguments:
dtype: The data type to convert to.
Returns:
A new ``StorageView`` instance.
)pbdoc")
.def("to_device",
[](const StorageView& view, Device device) {
ScopedDeviceSetter device_setter(view.device(), view.device_index());
StorageView converted = view.to(device);
synchronize_stream(view.device());
return converted;
},
py::arg("device"),
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Converts the storage to another device.
Arguments:
device: The device to copy the data to.
Returns:
A new ``StorageView`` instance.
)pbdoc")
;
}
}
}