forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoder.cc
More file actions
173 lines (147 loc) · 7.7 KB
/
Copy pathencoder.cc
File metadata and controls
173 lines (147 loc) · 7.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#include "module.h"
#include <ctranslate2/encoder.h>
#include "replica_pool.h"
namespace ctranslate2 {
namespace python {
// Python-facing encoder: a ReplicaPoolHelper specialized for Encoder that adds
// a synchronous forward_batch() entry point on top of the replica pool.
class EncoderWrapper : public ReplicaPoolHelper<Encoder> {
public:
  // Reuse the ReplicaPoolHelper constructors as-is.
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Submits one batch to the replica pool and blocks until its output is ready.
  //
  // `inputs` holds either string tokens, token IDs, or a dense StorageView.
  // `lengths` is only used for the dense StorageView form, where it is
  // mandatory: std::invalid_argument is thrown when it is missing.
  // An absent `token_type_ids` is forwarded as an empty batch of IDs.
  //
  // The shared lock on `_mutex` stays held until the result has been
  // retrieved. NOTE(review): presumably this keeps model unloading (which is
  // likely an exclusive lock holder — not visible here) from running
  // mid-request; confirm in ReplicaPoolHelper.
  EncoderForwardOutput
  forward_batch(const std::variant<BatchTokens, BatchIds, StorageView>& inputs,
                const std::optional<StorageView>& lengths,
                const std::optional<BatchIds>& token_type_ids) {
    std::future<EncoderForwardOutput> pending;
    std::shared_lock lock(_mutex);
    assert_model_is_ready();

    // Yields a fresh value on every call, exactly like an inline value_or().
    const auto type_ids = [&token_type_ids]() {
      return token_type_ids.value_or(std::vector<std::vector<size_t>>());
    };

    if (const auto* tokens = std::get_if<BatchTokens>(&inputs)) {
      pending = _pool->forward_batch_async(*tokens, type_ids());
    } else if (const auto* ids = std::get_if<BatchIds>(&inputs)) {
      pending = _pool->forward_batch_async(*ids, type_ids());
    } else if (const auto* dense = std::get_if<StorageView>(&inputs)) {
      if (!lengths)
        throw std::invalid_argument("lengths vector is required when passing a dense input");
      pending = _pool->forward_batch_async(*dense, lengths.value(), type_ids());
    }

    return pending.get();
  }
};
// Registers the Python bindings of the encoder API on module `m`:
//  - `EncoderForwardOutput`, the read-only result returned by forward_batch;
//  - `Encoder`, backed by EncoderWrapper (a ReplicaPoolHelper over Encoder).
//
// The potentially long-running methods (forward_batch, unload_model,
// load_model) are bound with py::gil_scoped_release so other Python threads
// can run while the C++ side works.
void register_encoder(py::module& m) {
  py::class_<EncoderForwardOutput>(m, "EncoderForwardOutput",
                                   "Forward output of an encoder model.")
    .def_readonly("last_hidden_state", &EncoderForwardOutput::last_hidden_state,
                  "Output of the last layer.")
    .def_readonly("pooler_output", &EncoderForwardOutput::pooler_output,
                  "Output of the pooling layer.")
    .def("__repr__", [](const EncoderForwardOutput& output) {
      // Reuse the Python repr of each field so the summary matches what users
      // see when printing the attributes individually.
      return "EncoderForwardOutput(last_hidden_state="
        + std::string(py::repr(py::cast(output.last_hidden_state)))
        + ", pooler_output=" + std::string(py::repr(py::cast(output.pooler_output)))
        + ")";
    })
    ;

  py::class_<EncoderWrapper>(
    m, "Encoder",
    R"pbdoc(
A text encoder.

Example:

  >>> encoder = ctranslate2.Encoder("model/", device="cpu")
  >>> encoder.forward_batch([["▁Hello", "▁world", "!"]])
)pbdoc")

    // Everything after `device` is keyword-only (py::kw_only below).
    .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
         py::arg("model_path"),
         py::arg("device")="cpu",
         py::kw_only(),
         py::arg("device_index")=0,
         py::arg("compute_type")="default",
         py::arg("inter_threads")=1,
         py::arg("intra_threads")=0,
         py::arg("max_queued_batches")=0,
         py::arg("flash_attention")=false,
         py::arg("tensor_parallel")=false,
         py::arg("files")=py::none(),
         R"pbdoc(
Initializes the encoder.

Arguments:
  model_path: Path to the CTranslate2 model directory.
  device: Device to use (possible values are: cpu, cuda, auto).
  device_index: Device IDs on which to place this encoder.
  compute_type: Model computation type or a dictionary mapping a device name
    to the computation type (possible values are: default, auto, int8, int8_float32,
    int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
  inter_threads: Maximum number of batches encoded in parallel.
  intra_threads: Number of OpenMP threads per encoder (0 to use a default value).
  max_queued_batches: Maximum number of batches in the queue (-1 for unlimited,
    0 for an automatic value). When the queue is full, future requests will block
    until a free slot is available.
  flash_attention: Run the model with Flash Attention 2 for the self-attention layers.
  tensor_parallel: Run the model in tensor parallel mode.
  files: Load model files from memory. This argument is a dictionary mapping
    file names to file contents as file-like or bytes objects. If this is set,
    :obj:`model_path` acts as an identifier for this model.
)pbdoc")

    .def_property_readonly("device", &EncoderWrapper::device,
                           "Device this encoder is running on.")
    .def_property_readonly("device_index", &EncoderWrapper::device_index,
                           "List of device IDs on which this encoder is running.")
    .def_property_readonly("compute_type", &EncoderWrapper::compute_type,
                           "Computation type used by the model.")
    .def_property_readonly("num_encoders", &EncoderWrapper::num_replicas,
                           "Number of encoders backing this instance.")
    .def_property_readonly("num_queued_batches", &EncoderWrapper::num_queued_batches,
                           "Number of batches waiting to be processed.")
    .def_property_readonly("tensor_parallel", &EncoderWrapper::tensor_parallel,
                           "Run model with tensor parallel mode.")
    .def_property_readonly("num_active_batches", &EncoderWrapper::num_active_batches,
                           "Number of batches waiting to be processed or currently processed.")

    // `lengths` and `token_type_ids` default to None, which pybind11 maps to
    // an empty std::optional in EncoderWrapper::forward_batch.
    .def("forward_batch", &EncoderWrapper::forward_batch,
         py::arg("inputs"),
         py::arg("lengths")=py::none(),
         py::arg("token_type_ids")=py::none(),
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
Forwards a batch of sequences in the encoder.

Arguments:
  inputs: A batch of sequences either as string tokens or token IDs.
    This argument can also be a dense int32 array with shape
    ``[batch_size, max_length]`` (e.g. created from a Numpy array or PyTorch tensor).
  lengths: The length of each sequence as an int32 array with shape
    ``[batch_size]``. Required when :obj:`inputs` is a dense array.
  token_type_ids: A batch of token type IDs with the same shape as :obj:`inputs`
    (``[batch_size, max_length]``).

Returns:
  The encoder model output.
)pbdoc")

    .def("unload_model", &EncoderWrapper::unload_model,
         py::arg("to_cpu")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
Unloads the model attached to this encoder but keeps enough runtime context
to quickly resume the encoder on the initial device.

Arguments:
  to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
)pbdoc")

    .def("load_model", &EncoderWrapper::load_model,
         py::arg("keep_cache")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
Loads the model back to the initial device.

Arguments:
  keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
)pbdoc")

    .def_property_readonly("model_is_loaded", &EncoderWrapper::model_is_loaded,
                           "Whether the model is loaded on the initial device and ready to be used.")
    ;
}
}
}