forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtranslator.h
More file actions
100 lines (81 loc) · 3.76 KB
/
Copy pathtranslator.h
File metadata and controls
100 lines (81 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#pragma once
#include <string>
#include <vector>
#include "models/model.h"
#include "translation_result.h"
namespace ctranslate2 {
struct TranslationOptions {
  // Upper bound on the batch size the model is run on; 0 forwards the input
  // as is. When translate() receives more inputs than this, they are sorted
  // by length internally to increase efficiency.
  size_t max_batch_size{0};

  // Beam width for beam search; 1 runs greedy search.
  size_t beam_size{2};

  // Length penalty value applied during beam search.
  float length_penalty{0.0f};

  // Constraints on the decoding length.
  size_t max_decoding_length{250};
  size_t min_decoding_length{1};

  // Sample randomly among the top K candidates. Not compatible with beam
  // search; 0 samples from the full output distribution.
  size_t sampling_topk{1};

  // Sampling temperature: higher values increase randomness.
  float sampling_temperature{1.0f};

  // Whether to use the vocabulary map included in the model directory.
  bool use_vmap{false};

  // Number of hypotheses to store in the TranslationResult class (should be
  // smaller than beam_size).
  size_t num_hypotheses{1};

  // Whether to store attention vectors in the TranslationResult class.
  bool return_attention{false};
};
// Holds all the information required to translate with a model. A copy of a
// Translator does not duplicate the model data, and the copy can safely be
// executed in parallel.
class Translator {
public:
  // Construction: from a model directory and a device, from a shared model
  // instance, or as a copy of another Translator.
  Translator(const std::string& model_dir, Device device = Device::CPU, int device_index = 0);
  Translator(const std::shared_ptr<const models::Model>& model);
  Translator(const Translator& other);

  // Single-input translation. The overload without `options` presumably
  // applies default-constructed TranslationOptions -- confirm in the
  // implementation.
  TranslationResult translate(const std::vector<std::string>& tokens);
  TranslationResult translate(const std::vector<std::string>& tokens,
                              const TranslationOptions& options);

  // Single-input translation that also receives a target prefix.
  // NOTE(review): presumably the output is constrained to start with
  // `target_prefix` -- confirm in the implementation.
  TranslationResult translate_with_prefix(const std::vector<std::string>& source,
                                          const std::vector<std::string>& target_prefix,
                                          const TranslationOptions& options);

  // Batch counterparts of translate(): one TranslationResult vector is
  // returned for a vector of tokenized inputs.
  std::vector<TranslationResult>
  translate_batch(const std::vector<std::vector<std::string>>& tokens);
  std::vector<TranslationResult>
  translate_batch(const std::vector<std::vector<std::string>>& tokens,
                  const TranslationOptions& options);

  // Batch counterpart of translate_with_prefix().
  std::vector<TranslationResult>
  translate_batch_with_prefix(const std::vector<std::vector<std::string>>& source,
                              const std::vector<std::vector<std::string>>& target_prefix,
                              const TranslationOptions& options);

  // Read-only accessors for the device, device index and compute type.
  Device device() const;
  int device_index() const;
  ComputeType compute_type() const;

  // Replace only the model; the device and compute type stay the same.
  void set_model(const std::string& model_dir);
  void set_model(const std::shared_ptr<const models::Model>& model);

private:
  // NOTE(review): presumably (re)creates _encoder and _decoder from _model --
  // confirm in the implementation.
  void make_graph();

  // Internal batch entry points. `target_prefix` is passed by pointer here,
  // presumably so that a null pointer can mean "no prefix" -- confirm.
  std::vector<TranslationResult>
  run_batch_translation_sorted(const std::vector<std::vector<std::string>>& source,
                               const TranslationOptions& options);
  std::vector<TranslationResult>
  run_batch_translation(const std::vector<std::vector<std::string>>& source,
                        const std::vector<std::vector<std::string>>* target_prefix,
                        const TranslationOptions& options);

  // Internal single-input entry point.
  TranslationResult run_translation(const std::vector<std::string>& source,
                                    const std::vector<std::string>* target_prefix,
                                    const TranslationOptions& options);

  // The model is held through shared ownership; the encoder and decoder are
  // exclusively owned by this instance.
  std::shared_ptr<const models::Model> _model;
  std::unique_ptr<layers::Encoder> _encoder;
  std::unique_ptr<layers::Decoder> _decoder;
};
}