-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Expand file tree
/
Copy pathexecutor.h
More file actions
1969 lines (1650 loc) · 92.2 KB
/
executor.h
File metadata and controls
1969 lines (1650 loc) · 92.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/runtimeDefaults.h"
#include <chrono>
#include <cstdint>
#include <deque>
#include <filesystem>
#include <list>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>
// Forward declaration of tensorrt_llm::mpi::MpiComm. Only the name is introduced here; the class itself is
// defined elsewhere.
namespace tensorrt_llm::mpi
{
class MpiComm;
} // namespace tensorrt_llm::mpi
// Forward declaration of tensorrt_llm::batch_manager::kv_cache_manager::BaseKVCacheManager. Only the name is
// introduced here; the class itself is defined elsewhere.
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class BaseKVCacheManager;
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
namespace tensorrt_llm::executor
{
/// @brief Re-exports the runtime's SizeType32 into the executor namespace.
using SizeType32 = tensorrt_llm::runtime::SizeType32;
/// @brief Version of TRT-LLM
/// @return The version as a C string. Declared noexcept.
/// NOTE(review): ownership/lifetime of the returned pointer is not visible here - presumably static storage; confirm.
char const* version() noexcept;
// Forward declarations of types that are defined elsewhere. Serialization is declared as a friend by several of
// the configuration classes below.
class Model;
class Serialization;
class DataTransceiverState;
/// @brief Sampling configuration
/// @details Groups the decoding parameters declared below: beam width, top-K / top-P sampling, penalties, etc.
/// Apart from the beam width, every constructor parameter is a std::optional that defaults to std::nullopt; the
/// documented default of each parameter is listed on the corresponding data member at the end of the class.
/// Only declarations are visible in this header; all member functions are defined out of line.
class SamplingConfig
{
public:
/// @brief Constructor for SamplingConfig
/// See description of parameters below
explicit SamplingConfig(SizeType32 beamWidth = 1, std::optional<SizeType32> const& topK = std::nullopt,
std::optional<FloatType> const& topP = std::nullopt, std::optional<FloatType> const& topPMin = std::nullopt,
std::optional<TokenIdType> const& topPResetIds = std::nullopt,
std::optional<FloatType> const& topPDecay = std::nullopt,
std::optional<RandomSeedType> const& seed = std::nullopt,
std::optional<FloatType> const& temperature = std::nullopt,
std::optional<SizeType32> const& minTokens = std::nullopt,
std::optional<FloatType> const& beamSearchDiversityRate = std::nullopt,
std::optional<FloatType> const& repetitionPenalty = std::nullopt,
std::optional<FloatType> const& presencePenalty = std::nullopt,
std::optional<FloatType> const& frequencyPenalty = std::nullopt,
std::optional<SizeType32> const& promptIgnoreLength = std::nullopt,
std::optional<FloatType> const& lengthPenalty = std::nullopt,
std::optional<SizeType32> const& earlyStopping = std::nullopt,
std::optional<SizeType32> const& noRepeatNgramSize = std::nullopt,
std::optional<SizeType32> const& numReturnSequences = std::nullopt,
std::optional<FloatType> const& minP = std::nullopt,
std::optional<std::vector<SizeType32>> const& beamWidthArray = std::nullopt);
/// @brief Equality comparison against another SamplingConfig (defined out of line).
bool operator==(SamplingConfig const& other) const;
// Getters. All values are returned by value; the optional ones are returned as std::optional.
[[nodiscard]] SizeType32 getBeamWidth() const;
[[nodiscard]] SizeType32 getNumReturnBeams() const;
[[nodiscard]] std::optional<SizeType32> getTopK() const;
[[nodiscard]] std::optional<FloatType> getTopP() const;
[[nodiscard]] std::optional<FloatType> getTopPMin() const;
// NOTE(review): returns std::optional<SizeType32> although mTopPResetIds and setTopPResetIds use TokenIdType.
// Presumably both aliases name the same integer type - confirm before relying on it.
[[nodiscard]] std::optional<SizeType32> getTopPResetIds() const;
[[nodiscard]] std::optional<FloatType> getTopPDecay() const;
[[nodiscard]] std::optional<RandomSeedType> getSeed() const;
[[nodiscard]] std::optional<FloatType> getTemperature() const;
[[nodiscard]] std::optional<SizeType32> getMinTokens() const;
[[nodiscard]] std::optional<FloatType> getBeamSearchDiversityRate() const;
[[nodiscard]] std::optional<FloatType> getRepetitionPenalty() const;
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
[[nodiscard]] std::optional<SizeType32> getPromptIgnoreLength() const;
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
[[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
[[nodiscard]] std::optional<SizeType32> getNoRepeatNgramSize() const;
[[nodiscard]] std::optional<SizeType32> getNumReturnSequences() const;
[[nodiscard]] std::optional<FloatType> getMinP() const;
[[nodiscard]] std::optional<std::vector<SizeType32>> getBeamWidthArray() const;
// Setters. There is one per constructor parameter, with the same parameter type as in the constructor.
void setBeamWidth(SizeType32 beamWidth);
void setTopK(std::optional<SizeType32> const& topK);
void setTopP(std::optional<FloatType> const& topP);
void setTopPMin(std::optional<FloatType> const& topPMin);
void setTopPResetIds(std::optional<TokenIdType> const& topPResetIds);
void setTopPDecay(std::optional<FloatType> const& topPDecay);
void setSeed(std::optional<RandomSeedType> const& seed);
void setTemperature(std::optional<FloatType> const& temperature);
void setMinTokens(std::optional<SizeType32> const& minTokens);
void setBeamSearchDiversityRate(std::optional<FloatType> const& beamSearchDiversityRate);
void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
void setPromptIgnoreLength(std::optional<SizeType32> const& promptIgnoreLength);
void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
void setNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
void setNumReturnSequences(std::optional<SizeType32> const& numReturnSequences);
void setMinP(std::optional<FloatType> const& minP);
void setBeamWidthArray(std::optional<std::vector<SizeType32>> const& beamWidthArray);
private:
// Static check helpers. Each takes a candidate value and returns a value of the same shape (mostly a const
// reference to an optional). The definitions are out of line; presumably they validate the value and hand it
// back so it can be used directly when initializing or assigning the matching member - TODO confirm.
// No helper is declared for the seed, the presence penalty or the frequency penalty.
static SizeType32 checkBeamWidth(SizeType32 beamWidth);
// NOTE(review): declared on std::optional<FloatType> although mTopK is std::optional<SizeType32>; a call with
// an optional<SizeType32> would go through a converting temporary. Looks unintentional - confirm against the
// definition before changing, since the declaration and the definition must match.
static std::optional<FloatType> const& checkTopK(std::optional<FloatType> const& topK);
static std::optional<FloatType> const& checkTopP(std::optional<FloatType> const& topP);
static std::optional<FloatType> const& checkTopPMin(std::optional<FloatType> const& topPMin);
static std::optional<TokenIdType> const& checkTopPResetIds(std::optional<TokenIdType> const& topPResetIds);
static std::optional<FloatType> const& checkTopPDecay(std::optional<FloatType> const& topPDecay);
static std::optional<FloatType> const& checkTemperature(std::optional<FloatType> const& temperature);
static std::optional<SizeType32> const& checkMinTokens(std::optional<SizeType32> const& minTokens);
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
std::optional<FloatType> const& beamSearchDiversityRate);
static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& repetitionpenalty);
static std::optional<SizeType32> const& checkPromptIgnoreLength(
std::optional<SizeType32> const& promptIgnoreLength);
static std::optional<FloatType> const& checkLengthPenalty(std::optional<FloatType> const& lengthPenalty);
static std::optional<SizeType32> const& checkEarlyStopping(std::optional<SizeType32> const& earlyStopping);
static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
// Unlike the other helpers this one also receives the beam width (see mNumReturnSequences for the documented
// relation between the two values).
static std::optional<SizeType32> const& checkNumReturnSequences(
std::optional<SizeType32> const& numReturnSequences, SizeType32 beamWidth);
static std::optional<FloatType> const& checkMinP(std::optional<FloatType> const& minP);
// Returns a pair of (reference to an optional beam width array, beam width).
// NOTE(review): the first pair element is a reference; whether it refers to the beamWidthArray argument is not
// visible here, so do not keep the result beyond the lifetime of the argument without checking the definition.
static std::pair<std::optional<std::vector<SizeType32>> const&, SizeType32 const> const checkBeamWidthArray(
std::optional<std::vector<SizeType32>> const& beamWidthArray, SizeType32 const beamWidth);
// Presumably recomputes mNumReturnBeams according to the rule documented on that member - definition not
// visible here, TODO confirm.
void updateNumReturnBeams();
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief The beam width. Default is 1 which disables beam search.
SizeType32 mBeamWidth;
/// @brief Controls number of logits to sample from. Default is 0 (all logits).
std::optional<SizeType32> mTopK;
/// @brief Controls the top-P probability to sample from. Default is 0.f
std::optional<FloatType> mTopP;
/// @brief Controls decay in the top-P algorithm. topPMin is lower-bound. Default is 1.e-6.
std::optional<FloatType> mTopPMin;
/// @brief Controls decay in the top-P algorithm. Indicates where to reset the decay. Default is 1.
std::optional<TokenIdType> mTopPResetIds;
/// @brief Controls decay in the top-P algorithm. The decay value. Default is 1.f
std::optional<FloatType> mTopPDecay;
/// @brief Controls the random seed used by the random number generator in sampling. Default is 0.
std::optional<RandomSeedType> mSeed;
/// @brief Controls the modulation of logits when sampling new tokens. It can have values > 0.f. Default is 1.0f
std::optional<FloatType> mTemperature;
/// @brief Lower bound on the number of tokens to generate. Values < 1 have no effect. Default is 1.
std::optional<SizeType32> mMinTokens;
/// @brief Controls the diversity in beam search.
std::optional<FloatType> mBeamSearchDiversityRate;
/// @brief Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f.
/// Values < 1.f encourage repetition, values > 1.f discourage it. Default is 1.f
std::optional<FloatType> mRepetitionPenalty;
/// @brief Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It
/// can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
std::optional<FloatType> mPresencePenalty;
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
/// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
std::optional<FloatType> mFrequencyPenalty;
/// @brief Controls how many tokens to ignore from the prompt for presence and frequency penalties. Values <= 0 have
/// no effect. Values > input (prompt) length will be clamped. Default is 0.
std::optional<SizeType32> mPromptIgnoreLength;
/// @brief Controls how to penalize longer sequences in beam search. Default is 0.f
std::optional<FloatType> mLengthPenalty;
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with
/// end_token). Default is 1.
std::optional<SizeType32> mEarlyStopping;
/// @brief Controls the size of the n-grams that must not be repeated. Default is 1 << 30.
std::optional<SizeType32> mNoRepeatNgramSize;
/// @brief The number of return sequences or beams. In beam search, the value should be less than or equal to
/// mBeamWidth. In sampling, it specifies the total number of independently generated sequences.
std::optional<SizeType32> mNumReturnSequences;
/// @brief The number of beams to return. It is equal to beamWidth unless numReturnSequences is set.
/// If beamWidth > 1 and numReturnSequences is set, then numReturnBeams is equal to numReturnSequences.
SizeType32 mNumReturnBeams;
/// @brief Controls the min_p scaling for sampling.
/// It masks every candidate x with P_x < min_p * P_max, where P_x is the probability of candidate x and P_max is
/// presumably the highest candidate probability. Default is 0.f
std::optional<FloatType> mMinP;
/// @brief Controls the beam width for each step for Variable-Beam-Width-Search.
std::optional<std::vector<SizeType32>> mBeamWidthArray;
};
/// @brief Additional output that should be gathered.
/// @details By default gather output of shape [beamWidth, x] from each generation phase.
/// If gatherContext is true, also gather output of shape [promptLen, x] from context phase.
class AdditionalModelOutput
{
public:
/// @brief Constructs the description of one additional output.
/// @param name The name of the additional output.
/// @param gatherContext Whether the output of the context phase should be gathered as well. Default is false.
explicit AdditionalModelOutput(std::string name, bool gatherContext = false);
/// @brief Equality comparison against another AdditionalModelOutput (defined out of line).
bool operator==(AdditionalModelOutput const& other) const;
/// @brief The name of the additional output. Public data member.
std::string name;
/// @brief If true, the output of the context phase is gathered too. Default is false.
bool gatherContext{false};
};
/// @brief Configuration that controls the outputs of a Result
/// @details All settings are public data members; the constructor takes one argument per member, in the same
/// order, and every argument has a default.
class OutputConfig
{
public:
/// @brief Constructs the output configuration. See the data members below for the meaning of each parameter.
explicit OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false,
bool returnGenerationLogits = false, bool excludeInputFromOutput = false, bool returnEncoderOutput = false,
bool returnPerfMetrics = false,
std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs = std::nullopt);
/// @brief Controls if Result should contain log probabilities. Default is false.
bool returnLogProbs;
/// @brief Controls if Result should contain the context logits. Default is false.
bool returnContextLogits;
/// @brief Controls if Result should contain the generation logits. Default is false.
bool returnGenerationLogits;
/// @brief Controls if the input tokens are left out of the output tokens in Result. Going by the flag name,
/// true presumably means the input tokens are excluded. Default is false.
bool excludeInputFromOutput;
/// @brief Controls if Result should contain encoder output hidden states (for encoder-only and encoder-decoder
/// models). Default is false.
bool returnEncoderOutput;
/// @brief Controls if Result should contain performance metrics. The constructor default is false.
bool returnPerfMetrics;
/// @brief The additional outputs to gather from the model. The constructor default is std::nullopt.
std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs;
};
/// @brief Configuration for speculative decoding with external draft tokens.
/// Allows to include draft tokens, draft logits and specify acceptance threshold.
class ExternalDraftTokensConfig
{
public:
/// @brief Constructs the configuration.
/// @param tokens The draft tokens.
/// @param logits The optional draft logits. Expected shape: [num_draft_tokens, vocab_size].
/// @param acceptanceThreshold The optional acceptance threshold. Must be > 0.f and <= 1.f
/// @param fastLogits Optional flag to use direct transfer for draft logits.
explicit ExternalDraftTokensConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
std::optional<FloatType> const& acceptanceThreshold = std::nullopt,
std::optional<bool> const& fastLogits = std::nullopt);
// Getters; every value is returned by value.
[[nodiscard]] VecTokens getTokens() const;
[[nodiscard]] std::optional<Tensor> getLogits() const;
[[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const;
[[nodiscard]] std::optional<bool> getFastLogits() const;
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief The draft tokens
VecTokens mTokens;
/// @brief The draft logits. Expected shape: [num_draft_tokens, vocab_size].
std::optional<Tensor> mLogits;
/// @brief The acceptance threshold. Must be > 0.f and <= 1.f
std::optional<FloatType> mAcceptanceThreshold;
/// @brief Use direct transfer for draft logits
std::optional<bool> mFastLogits;
};
/// @brief Configuration for prompt tuning
class PromptTuningConfig
{
public:
/// @brief Constructs the configuration.
/// @param embeddingTable The prompt embedding table. Expected shape: [task vocab_size, hidden_size]. Data type
/// must match model weights.
/// @param inputTokenExtraIds The optional input token extra ids for KV Cache reuse when p-tuning is enabled.
explicit PromptTuningConfig(
Tensor embeddingTable, std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt);
// Getters; both values are returned by value.
[[nodiscard]] Tensor getEmbeddingTable() const;
[[nodiscard]] std::optional<VecTokenExtraIds> getInputTokenExtraIds() const;
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief The prompt embedding table. Expected shape: [task vocab_size, hidden_size]. Data type must match model
/// weights.
Tensor mEmbeddingTable;
/// @brief The input token extra ids for KV Cache reuse when p-tuning is enabled
std::optional<VecTokenExtraIds> mInputTokenExtraIds;
};
/// @brief Multimodal input data class
/// @details Carries hashes, positions and lengths of multimodal items plus four optional vectors (UUIDs and the
/// "run" arrays). See the data members for the meaning of each field.
class MultimodalInput
{
public:
/// @brief Constructs the multimodal input. The first three vectors are required; the remaining arguments are
/// optional and default to std::nullopt.
explicit MultimodalInput(std::vector<std::vector<SizeType32>> multimodalHashes,
std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths,
std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalItemRunCuOffsets = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalRunPositions = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalRunLengths = std::nullopt);
// Getters for the required fields return by value (a copy of the vector on each call).
[[nodiscard]] std::vector<std::vector<SizeType32>> getMultimodalHashes() const;
[[nodiscard]] std::vector<SizeType32> getMultimodalPositions() const;
[[nodiscard]] std::vector<SizeType32> getMultimodalLengths() const;
// Getters for the optional fields return a const reference, which is only valid while this object is alive.
[[nodiscard]] std::optional<std::vector<std::optional<std::string>>> const& getMultimodalUuids() const;
[[nodiscard]] std::optional<std::vector<SizeType32>> const& getMultimodalItemRunCuOffsets() const;
[[nodiscard]] std::optional<std::vector<SizeType32>> const& getMultimodalRunPositions() const;
[[nodiscard]] std::optional<std::vector<SizeType32>> const& getMultimodalRunLengths() const;
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief The multimodal hashes
std::vector<std::vector<SizeType32>> mMultimodalHashes;
/// @brief The multimodal positions
std::vector<SizeType32> mMultimodalPositions;
/// @brief The multimodal lengths
std::vector<SizeType32> mMultimodalLengths;
/// @brief Optional user-provided UUIDs for multimodal items.
/// When provided, these are returned in KV cache events instead of content hashes.
std::optional<std::vector<std::optional<std::string>>> mMultimodalUuids;
/// @brief Optional offsets indexing the flat exact multimodal run arrays per item.
std::optional<std::vector<SizeType32>> mMultimodalItemRunCuOffsets;
/// @brief Optional prompt start positions for flat exact multimodal token runs.
std::optional<std::vector<SizeType32>> mMultimodalRunPositions;
/// @brief Optional lengths for flat exact multimodal token runs.
std::optional<std::vector<SizeType32>> mMultimodalRunLengths;
};
/// @brief Configuration for mrope
class MropeConfig
{
public:
/// @brief Constructs the configuration from the rotary cache tensor and the position deltas.
/// NOTE(review): the first parameter name is misspelled ("Roraty") and says "SinCos" while the getter and the
/// member say "CosSin". Parameter names in a declaration do not affect callers, so it could be renamed to
/// mropeRotaryCosSin - confirm the intended sin/cos ordering first.
explicit MropeConfig(Tensor mropeRoratySinCos, SizeType32 mropePositionDeltas);
// Getters; both values are returned by value.
[[nodiscard]] Tensor getMRopeRotaryCosSin() const;
[[nodiscard]] SizeType32 getMRopePositionDeltas() const;
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief The mrope rotary sin and cos cache. Expected shape: [maxPositionEmbeddings*rotaryEmbeddingDim]. Data
/// type must be float32.
Tensor mMRopeRotaryCosSin;
/// @brief The mrope position deltas
SizeType32 mMRopePositionDeltas;
};
/// @brief Configuration for LoRA
class LoraConfig
{
public:
/// @brief Constructs the configuration.
/// @param taskId The Lora task id.
/// @param weights The optional Lora weights. See TRT-LLM documentation for expected shapes and types.
/// @param config The optional Lora configuration tensor. See TRT-LLM documentation for a detailed description.
explicit LoraConfig(
IdType taskId, std::optional<Tensor> weights = std::nullopt, std::optional<Tensor> config = std::nullopt);
// Getters; every value is returned by value.
[[nodiscard]] IdType getTaskId() const;
[[nodiscard]] std::optional<Tensor> getWeights() const;
[[nodiscard]] std::optional<Tensor> getConfig() const;
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief The Lora task id
IdType mTaskId;
/// @brief The Lora weights. See TRT-LLM documentation for expected shapes and types
std::optional<Tensor> mWeights;
/// @brief The Lora configuration. See TRT-LLM documentation for detailed description of the config tensor
std::optional<Tensor> mConfig;
};
/// @brief Settings for Look-Ahead speculative decoding.
/// @details Stores three sizes: the lookahead window size, the n-gram size and the verification set size.
struct LookaheadDecodingConfig
{
    /// @brief Window size used by the zero-argument constructor.
    static constexpr SizeType32 kDefaultLookaheadDecodingWindow = 4;
    /// @brief N-gram size used by the zero-argument constructor.
    static constexpr SizeType32 kDefaultLookaheadDecodingNgram = 3;
    /// @brief Verification set size used by the zero-argument constructor.
    static constexpr SizeType32 kDefaultLookaheadDecodingVerificationSet = 4;

    /// @brief Builds a configuration from explicit sizes (defined out of line).
    LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);

    /// @brief Builds a configuration from the kDefault* constants by delegating to the three-argument constructor.
    explicit LookaheadDecodingConfig()
        : LookaheadDecodingConfig(kDefaultLookaheadDecodingWindow, kDefaultLookaheadDecodingNgram,
            kDefaultLookaheadDecodingVerificationSet)
    {
    }

    /// @brief Equality comparison against another configuration (defined out of line).
    bool operator==(LookaheadDecodingConfig const& other) const;

    /// @brief Returns three sizes as one tuple - presumably (windowSize, ngramSize, verificationSetSize) in that
    /// order; confirm against the definition.
    [[nodiscard]] std::tuple<SizeType32 const, SizeType32 const, SizeType32 const> get() const;

    /// @brief Individual getters for the three sizes.
    [[nodiscard]] SizeType32 getWindowSize() const;
    [[nodiscard]] SizeType32 getNgramSize() const;
    [[nodiscard]] SizeType32 getVerificationSetSize() const;

    /// @brief Returns <maxDecodingTokens, maxPathLen, maxDraftTokens, maxDraftPathLen>.
    [[nodiscard]] std::tuple<SizeType32, SizeType32, SizeType32, SizeType32> calculateSpeculativeResource() const;

    /// @brief Static variant taking the three sizes as arguments; returns a tuple of the same type as
    /// calculateSpeculativeResource().
    static std::tuple<SizeType32, SizeType32, SizeType32, SizeType32> calculateSpeculativeResourceTuple(
        SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);

    /// @brief Returns true when `this` can be executed on resources defined by `that`.
    [[nodiscard]] bool isLE(LookaheadDecodingConfig const& that) const;

    /// @brief Returns true when the parameter combination is valid.
    static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept;

private:
    // Serialization may access the private members.
    friend class Serialization;

    /// Number of NGrams in lookahead branch per step.
    SizeType32 mWindowSize;
    /// Number of tokens per NGram.
    SizeType32 mNgramSize;
    /// Number of NGrams in verification branch per step.
    SizeType32 mVerificationSetSize;
};
/// @brief Configuration for EAGLE decoding. See the data members for the EAGLE-1 and Eagle-2 specific options.
struct EagleConfig
{
/// @brief Constructs the configuration. Every parameter has a default; see the data members for their meaning.
explicit EagleConfig(std::optional<EagleChoices> eagleChoices = std::nullopt, bool greedySampling = true,
std::optional<float> posteriorThreshold = std::nullopt, bool useDynamicTree = false,
std::optional<SizeType32> dynamicTreeMaxTopK = std::nullopt);
/// @brief Equality comparison against another EagleConfig (defined out of line).
bool operator==(EagleConfig const& other) const;
// Getters; every value is returned by value.
[[nodiscard]] std::optional<EagleChoices> getEagleChoices() const;
[[nodiscard]] std::optional<float> getPosteriorThreshold() const;
[[nodiscard]] bool isGreedySampling() const;
[[nodiscard]] bool useDynamicTree() const;
[[nodiscard]] std::optional<SizeType32> getDynamicTreeMaxTopK() const;
private:
// Takes an optional float and returns a const reference to an optional float. Presumably validates the
// posterior threshold and hands it back - definition not visible here, TODO confirm. Note that it is a
// non-static, non-const member function.
std::optional<float> const& checkPosteriorValue(std::optional<float> const& value);
// NOTE(review): this second access specifier is redundant; the section is already private.
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief choices forming tree for EAGLE-1.
std::optional<EagleChoices> mEagleChoices;
/// @brief Flag to use greedy or typical acceptance. The constructor default is true.
bool mGreedySampling;
/// @brief Minimum token probability of the typical acceptance.
/// Corresponds to epsilon in https://arxiv.org/pdf/2401.10774.
/// Default is 0.09f.
std::optional<float> mPosteriorThreshold;
/// @brief Flag to use Eagle-2. The constructor default is false.
bool mUseDynamicTree;
/// @brief Number of draft tokens expand for each node in Eagle-2
std::optional<SizeType32> mDynamicTreeMaxTopK;
};
/// @brief Parameters describing the context phase of a request: the request id used in the context phase, the
/// first tokens generated by the context executor, optional draft tokens, an opaque state pointer, the optional
/// context data parallel rank and the optional disaggregated info endpoint.
/// @details The state is held in a std::unique_ptr<void> with a custom deleter, which is move-only; this is why
/// the copy/move constructors, the assignment operators and the destructor are all declared explicitly.
class ContextPhaseParams
{
public:
using RequestIdType = std::uint64_t;
/// @brief Constructs the parameters without passing a state.
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens,
std::optional<SizeType32> ctxDpRank = std::nullopt,
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
/// @brief Constructs the parameters from a raw, type-erased state pointer.
/// NOTE(review): mState is an owning pointer, so this constructor presumably takes ownership of `state`; the
/// expected dynamic type of `state` is not visible here - confirm against the definition.
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens,
std::optional<SizeType32> ctxDpRank = std::nullopt,
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
/// @brief Constructs the parameters from a serialized state given as a byte buffer.
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::vector<char> const& serializedState,
std::optional<VecTokens> draftTokens, std::optional<SizeType32> ctxDpRank = std::nullopt,
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
// Special member functions, all user-declared and defined out of line.
ContextPhaseParams(ContextPhaseParams const&);
ContextPhaseParams(ContextPhaseParams&&) noexcept;
ContextPhaseParams& operator=(ContextPhaseParams const&);
ContextPhaseParams& operator=(ContextPhaseParams&&) noexcept;
~ContextPhaseParams();
[[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;
// The `const&` ref-qualifier restricts the reference-returning getters to lvalues, so the returned reference
// cannot be obtained from a temporary object.
[[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
void setFirstGenTokens(VecTokens const& firstGenTokens) noexcept;
[[nodiscard]] std::optional<VecTokens> const& getDraftTokens() const& noexcept;
void setDraftTokens(std::optional<VecTokens> const& draftTokens) noexcept;
// The `&&` ref-qualifier restricts this function to rvalues (e.g. std::move(params).popFirstGenTokens()).
// Presumably it moves the first generated tokens out of the object - TODO confirm.
[[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
[[nodiscard]] RequestIdType getReqId() const noexcept;
void setReqId(RequestIdType const& reqId) noexcept;
// Access to the opaque state pointer; const and non-const overloads.
[[nodiscard]] void const* getState() const noexcept;
[[nodiscard]] void* getState() noexcept;
// Presumably hands the state pointer over to the caller and leaves this object without a state, by analogy
// with std::unique_ptr::release - TODO confirm; the caller would then be responsible for destroying it.
[[nodiscard]] void* releaseState() noexcept;
// NOTE(review): returns a vector by value from a noexcept function; if building the vector allocates and the
// allocation fails, the exception cannot propagate and std::terminate is called.
[[nodiscard]] std::vector<char> getSerializedState() const noexcept;
[[nodiscard]] std::optional<SizeType32> getCtxDpRank() const noexcept;
void setCtxDpRank(std::optional<SizeType32> const& ctxDpRank) noexcept;
[[nodiscard]] std::optional<std::string> const& getDisaggInfoEndpoint() const noexcept;
void setDisaggInfoEndpoint(std::optional<std::string> const& disaggInfoEndpoint) noexcept;
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
// Deleter function used by StatePtr; it receives the type-erased state pointer.
static void deleter(void const* data);
// Owning pointer to the type-erased state; the deleter is stored as a function pointer.
using StatePtr = std::unique_ptr<void, decltype(&deleter)>;
/// @brief This request corresponds to the request ID in the context phase.
RequestIdType mReqId{0};
/// @brief The first tokens generated by context executor
VecTokens mFirstGenTokens;
/// @brief Context phase state of this request. Initialized to nullptr with `deleter` as the deleter.
StatePtr mState{nullptr, deleter};
/// @brief The draft tokens generated by context executor
std::optional<VecTokens> mDraftTokens;
/// @brief The context phase data parallel rank
std::optional<SizeType32> mCtxDpRank;
/// @brief The disaggregated info endpoint
std::optional<std::string> mDisaggInfoEndpoint;
};
/// @brief Configuration for speculative decoding (both draft and target models)
class SpeculativeDecodingConfig
{
public:
/// @brief Constructs the configuration.
/// @param fastLogits Whether to send the logits tensor directly from draft to target model. Default is false.
explicit SpeculativeDecodingConfig(bool fastLogits = false);
/// @brief Equality comparison against another SpeculativeDecodingConfig (defined out of line).
bool operator==(SpeculativeDecodingConfig const& other) const;
/// @brief Send logits tensor directly from draft to target model. Public data member.
bool fastLogits;
private:
// Grants the Serialization class access to the class; there are no private data members in this class.
friend class Serialization;
};
/// @brief Guided decoding parameters for a request.
/// @details Combines a guide type with an optional guide string whose interpretation depends on the type.
class GuidedDecodingParams
{
public:
/// @brief The kind of constraint that is applied to the generated text.
enum class GuideType
{
/// @brief The generated text is amenable to json format.
kJSON = 0,
/// @brief The generated text is amenable to json format with additional user-specified restrictions, namely
/// schema.
kJSON_SCHEMA = 1,
/// @brief The generated text is amenable to the user-specified regular expression.
kREGEX = 2,
/// @brief The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar.
/// EBNF grammar is widely-used to express context-free grammars.
kEBNF_GRAMMAR = 3,
/// @brief The generated text is amenable to the XGrammar structural tag.
kSTRUCTURAL_TAG = 4,
};
/// @brief Constructs the parameters.
/// @param guideType The guide type. See GuideType.
/// @param guide The optional guide string; its meaning depends on guideType. Default is std::nullopt.
explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt);
/// @brief Equality comparison against another GuidedDecodingParams (defined out of line).
bool operator==(GuidedDecodingParams const& other) const;
// Getters; both values are returned by value.
[[nodiscard]] GuideType getGuideType() const;
[[nodiscard]] std::optional<std::string> getGuide() const;
private:
// Grants the Serialization class access to the private members.
friend class Serialization;
/// @brief The guide type. See GuideType.
GuideType mGuideType;
/// @brief The detailed guide string. It could be a json schema, a regular expression or an EBNF grammar depending
/// on mGuideType (presumably a structural tag description for kSTRUCTURAL_TAG - confirm).
std::optional<std::string> mGuide;
};
using RetentionPriority = SizeType32;
struct RetentionPriorityAndDuration
{
RetentionPriorityAndDuration(std::optional<RetentionPriority> const& retentionPriority,
std::optional<std::chrono::milliseconds> const& durationMs)
: retentionPriority{retentionPriority}
, durationMs{durationMs}
{
}
std::optional<RetentionPriority> retentionPriority;
std::optional<std::chrono::milliseconds> durationMs;
};
/// @brief Configuration for the request's retention in the KV Cache
class KvCacheRetentionConfig
{
public:
static constexpr RetentionPriority kMinRetentionPriority = 0;
static constexpr RetentionPriority kMaxRetentionPriority = 100;
static constexpr RetentionPriority kDefaultRetentionPriority = 35;
/// @brief A single entry to set block priorities over a token range. Earlier ranges always take priority over later
/// ones. For example, with a block size of 16, a range of [0, 17] would be applied to the first two blocks.
struct TokenRangeRetentionConfig
{
public:
explicit TokenRangeRetentionConfig(SizeType32 tokenStart, std::optional<SizeType32> tokenEnd = std::nullopt,
RetentionPriority priority = KvCacheRetentionConfig::kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt);
bool operator==(TokenRangeRetentionConfig const& other) const;
/// @brief The first token of this range.
SizeType32 tokenStart;
/// @brief The final token of this range. The end is not included in the range. This can be set to std::nullopt
/// to extend the range to the end of the sequence.
std::optional<SizeType32> tokenEnd;
/// @brief The priority of this token range. Higher priorities are less likely to be evicted or offloaded.
RetentionPriority priority;
/// @brief The duration in ms that the block should remain at the given priority level. Set to std::nullopt to
/// have no expiration time, and keep the block at the given priority level until it gets reclaimed. After the
/// duration has passed, the block will be moved back to the `kDefaultRetentionPriority` level.
std::optional<std::chrono::milliseconds> durationMs;
};
explicit KvCacheRetentionConfig()
: KvCacheRetentionConfig({}, kDefaultRetentionPriority)
{
}
explicit KvCacheRetentionConfig(std::vector<TokenRangeRetentionConfig> const& tokenRangeRetentionPriorities,
RetentionPriority decodeRetentionPriority = kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> decodeDurationMs = std::nullopt,
KvCacheTransferMode transferMode = KvCacheTransferMode::DRAM, std::string const& directory = "");
[[nodiscard]] std::vector<TokenRangeRetentionConfig> getTokenRangeRetentionConfigs() const;
[[nodiscard]] RetentionPriority getDecodeRetentionPriority() const;
[[nodiscard]] std::optional<std::chrono::milliseconds> getDecodeDurationMs() const;
[[nodiscard]] KvCacheTransferMode getTransferMode() const;
[[nodiscard]] std::string const& getDirectory() const;
/// @brief Convert the token range data into an entry per kv block.
/// @param blockSize The block size used to map token ranges onto blocks.
/// NOTE(review): presumably the number of tokens per kv block -- confirm.
/// @param seqLen The sequence length. NOTE(review): presumably counted in tokens -- confirm.
/// @return A vector of RetentionPriorityAndDuration, i.e. the priority and duration for each block.
[[nodiscard]] std::vector<RetentionPriorityAndDuration> getPerBlockRetentionPriorityDuration(
SizeType32 blockSize, SizeType32 seqLen) const;
/// @brief Two configs are equal when their token ranges, decode priority, decode duration, transfer mode and
/// directory all compare equal.
bool operator==(KvCacheRetentionConfig const& other) const
{
    // Compare member by member in declaration order and stop at the first mismatch.
    if (!(mTokenRangeRetentionConfigs == other.mTokenRangeRetentionConfigs))
    {
        return false;
    }
    if (!(mDecodeRetentionPriority == other.mDecodeRetentionPriority))
    {
        return false;
    }
    if (!(mDecodeDurationMs == other.mDecodeDurationMs))
    {
        return false;
    }
    if (!(mTransferMode == other.mTransferMode))
    {
        return false;
    }
    return mDirectory == other.mDirectory;
}
private:
/// @brief The token ranges and priority levels to update. Ranges must be non-overlapping, but they do not have to
/// be listed in order. For example [(0, 64), (100, 128), (70, 80)] is valid, whereas
/// [(0, 64), (60, 128)] is not.
std::vector<TokenRangeRetentionConfig> mTokenRangeRetentionConfigs;
/// @brief The priority level to assign to blocks allocated in the decode phase
RetentionPriority mDecodeRetentionPriority;
/// @brief The duration in ms that decode blocks should remain at their assigned priority level.
/// NOTE(review): presumably std::nullopt means no expiration, mirroring TokenRangeRetentionConfig::durationMs --
/// confirm.
std::optional<std::chrono::milliseconds> mDecodeDurationMs;
/// @brief The transfer mode for the block.
KvCacheTransferMode mTransferMode;
/// @brief Name of the directory if transfer mode is GDS or POSIX_DEBUG_FALLBACK.
std::string mDirectory;
};
/// @brief A class that holds information about the request
/// @details Only the interface is declared here; the only data member is mImpl, a pointer to the forward-declared
/// Impl class.
class Request
{
public:
/// @brief Default value of the constructor's `priority` argument.
static constexpr PriorityType kDefaultPriority = 0.5;
/// @brief The Request constructor
/// @note Only inputTokenIds and maxTokens are required; every other argument has a default value.
/// @param inputTokenIds The input token ids
/// @param maxTokens The maximum number of tokens to generate
/// @param streaming Indicates if the responses should be streamed or not. Default is false.
/// @param samplingConfig The sampling configuration
/// @param outputConfig The output configuration
/// @param endId The end token id
/// @param padId The pad token id
/// @param positionIds The input position ids
/// @param badWords A list of bad words tokens. Each "word" can be composed of multiple tokens
/// @param stopWords A list of stop words tokens. Each "word" can be composed of multiple tokens
/// @param embeddingBias The embedding bias tensor. Expected shape is [vocab_size]
/// @param externalDraftTokensConfig The speculative decoding with external draft tokens configuration
/// @param pTuningConfig The prompt tuning configuration
/// @param multimodalInput The multimodal input {multimodalHashes, multimodalPositions, multimodalLengths, optional
/// exact prompt runs}
/// @param multimodalEmbedding The multimodal embedding tensor. Expected shape is [num_multimodal_tokens,
/// hidden_dim]
/// @param mRopeConfig The mrope configuration
/// @param loraConfig The LoRA configuration
/// @param lookaheadConfig The lookahead speculative decoding configuration
/// @param kvCacheRetentionConfig The configuration used for KV cache block eviction.
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
/// name provided to the ExecutorConfig.
/// @param logitsPostProcessor The logits postprocessor dynamically specified per request; only supported with
/// replicate=false or no tensor parallelism.
/// @param encoderInputTokenIds The encoder input token ids for encoder-decoder models, or encoder-only models
/// @param clientId The optional client id to associate with the request; Response::getClientId() reports the client
/// id of the request a response was generated for.
/// @param returnAllGeneratedTokens Indicates whether to return the full beams or just the newly generated tokens
/// after every streaming step.
/// @param priority Sets the execution priority of this request.
/// @param type Indicate the request type for disaggregated serving mode.
/// @param contextPhaseParams Generated token ID from context only executor.
/// @param encoderInputFeatures Encoder input features for multimodal models.
/// @param encoderOutputLength Encoder output length if encoder input and output have different lengths (due to
/// convolution down-sampling, etc.)
/// @param crossAttentionMask Cross attention mask.
/// @param numReturnSequences The number of returning sequences.
/// @param eagleConfig The EAGLE speculative decoding configuration
/// @param skipCrossAttnBlocks Skip the cross attention transformer blocks or not.
/// @param guidedDecodingParams The guided decoding parameters.
/// @param languageAdapterUid Task Uid for language adapter.
/// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut
/// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
/// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
/// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string.
/// @param disaggRequestId Disaggregated request ID.
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
std::optional<std::list<VecTokens>> badWords = std::nullopt,
std::optional<std::list<VecTokens>> stopWords = std::nullopt,
std::optional<Tensor> embeddingBias = std::nullopt,
std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig = std::nullopt,
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
std::optional<MultimodalInput> multimodalInput = std::nullopt,
std::optional<Tensor> multimodalEmbedding = std::nullopt, std::optional<MropeConfig> mRopeConfig = std::nullopt,
std::optional<LoraConfig> loraConfig = std::nullopt,
std::optional<LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
std::optional<KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
std::optional<std::string> logitsPostProcessorName = std::nullopt,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
std::optional<VecTokens> encoderInputTokenIds = std::nullopt, std::optional<IdType> clientId = std::nullopt,
bool returnAllGeneratedTokens = false, PriorityType priority = kDefaultPriority,
RequestType type = RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION,
std::optional<ContextPhaseParams> contextPhaseParams = std::nullopt,
std::optional<Tensor> encoderInputFeatures = std::nullopt,
std::optional<SizeType32> encoderOutputLength = std::nullopt,
std::optional<Tensor> crossAttentionMask = std::nullopt, SizeType32 numReturnSequences = 1,
std::optional<EagleConfig> eagleConfig = std::nullopt, std::optional<Tensor> skipCrossAttnBlocks = std::nullopt,
std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
std::optional<IdType> disaggRequestId = std::nullopt);
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
/// @brief Dynamic logits postprocessor name will be "dynamic" + requestId
static auto constexpr kDynamicPostProcessorNamePrefix = "dynamic";
// Copy/move construction, copy/move assignment and the destructor are all user-declared here without inline
// definitions. The move operations are declared noexcept.
Request(Request const& other);
Request(Request&& other) noexcept;
Request& operator=(Request const& other);
Request& operator=(Request&& other) noexcept;
~Request();
// Getters. All return by value except getContextPhaseParams(), which returns a const reference.
// NOTE(review): no getter is declared for the constructor's numReturnSequences argument.
[[nodiscard]] VecTokens getInputTokenIds() const;
[[nodiscard]] SizeType32 getMaxTokens() const;
[[nodiscard]] bool getStreaming() const;
[[nodiscard]] SamplingConfig getSamplingConfig() const;
[[nodiscard]] OutputConfig getOutputConfig() const;
[[nodiscard]] std::optional<SizeType32> getEndId() const;
[[nodiscard]] std::optional<SizeType32> getPadId() const;
[[nodiscard]] std::optional<std::vector<SizeType32>> getPositionIds() const;
[[nodiscard]] std::optional<std::list<VecTokens>> getBadWords() const;
[[nodiscard]] std::optional<std::list<VecTokens>> getStopWords() const;
[[nodiscard]] std::optional<Tensor> getEmbeddingBias() const;
[[nodiscard]] std::optional<ExternalDraftTokensConfig> getExternalDraftTokensConfig() const;
[[nodiscard]] std::optional<PromptTuningConfig> getPromptTuningConfig() const;
[[nodiscard]] std::optional<MultimodalInput> getMultimodalInput() const;
[[nodiscard]] std::optional<Tensor> getMultimodalEmbedding() const;
[[nodiscard]] std::optional<MropeConfig> getMropeConfig() const;
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
[[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadConfig() const;
[[nodiscard]] std::optional<KvCacheRetentionConfig> getKvCacheRetentionConfig() const;
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
[[nodiscard]] std::optional<LogitsPostProcessor> getLogitsPostProcessor() const;
[[nodiscard]] std::optional<VecTokens> getEncoderInputTokenIds() const;
[[nodiscard]] std::optional<IdType> getClientId() const;
[[nodiscard]] PriorityType getPriority() const;
[[nodiscard]] bool getReturnAllGeneratedTokens() const;
[[nodiscard]] std::optional<ContextPhaseParams> const& getContextPhaseParams() const;
[[nodiscard]] std::optional<Tensor> getEncoderInputFeatures() const;
[[nodiscard]] std::optional<SizeType32> getEncoderOutputLength() const;
[[nodiscard]] std::optional<Tensor> getCrossAttentionMask() const;
[[nodiscard]] RequestType getRequestType() const;
[[nodiscard]] std::optional<EagleConfig> getEagleConfig() const;
[[nodiscard]] std::optional<Tensor> getSkipCrossAttnBlocks() const;
[[nodiscard]] std::optional<GuidedDecodingParams> getGuidedDecodingParams() const;
[[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const;
[[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
// NOTE(review): this getter has no matching constructor argument or setter in this class; presumably the names
// come from the output configuration -- confirm against the definition.
[[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
[[nodiscard]] std::optional<IdType> getDisaggRequestId() const;
// Setters. No setter is declared for inputTokenIds, maxTokens or numReturnSequences.
// setLogitsPostProcessor() and setEagleConfig() take a std::optional; the other setters take a plain value.
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig const& config);
void setOutputConfig(OutputConfig const& outputConfig);
void setEndId(SizeType32 endId);
void setPadId(SizeType32 padId);
void setPositionIds(std::vector<SizeType32> const& positionIds);
void setBadWords(std::list<VecTokens> const& badWords);
void setStopWords(std::list<VecTokens> const& stopWords);
void setEmbeddingBias(Tensor const& embeddingBias);
void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig);
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
void setMultimodalEmbedding(Tensor const& multimodalEmbedding);
void setMultimodalInput(MultimodalInput const& multimodalInput);
void setMropeConfig(MropeConfig const& mRopeConfig);
void setLoraConfig(LoraConfig const& loraConfig);
void setLookaheadConfig(LookaheadDecodingConfig const& lookaheadConfig);
void setKvCacheRetentionConfig(KvCacheRetentionConfig const& kvCacheRetentionConfig);
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
void setLogitsPostProcessor(std::optional<LogitsPostProcessor> const& logitsPostProcessor);
void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds);
void setClientId(IdType clientId);
void setPriority(PriorityType priority);
void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens);
void setRequestType(RequestType const& requestType);
void setContextPhaseParams(ContextPhaseParams contextPhaseParams);
void setEncoderInputFeatures(Tensor encoderInputFeatures);
void setEncoderOutputLength(SizeType32 encoderOutputLength);
void setCrossAttentionMask(Tensor crossAttentionMask);
void setEagleConfig(std::optional<EagleConfig> const& eagleConfig);
void setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks);
void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams);
void setLanguageAdapterUid(SizeType32 languageAdapterUid);
void setAllottedTimeMs(MillisecondsType allottedTimeMs);
void setCacheSaltID(CacheSaltIDType cacheSaltID);
void setDisaggRequestId(IdType disaggRequestId);
private:
// Serialization is a friend and can therefore reach the private implementation pointer.
friend class Serialization;
// Impl is only forward-declared in this class; mImpl is the sole data member.
class Impl;
std::unique_ptr<Impl> mImpl;
};
/// @brief Struct that holds the logits information when using direct transfer
/// @details Neither field has a default member initializer, so a default-initialized instance leaves both
/// indeterminate until they are assigned.
struct SpeculativeDecodingFastLogitsInfo
{
/// @brief Draft request id
uint64_t draftRequestId;
/// @brief MPI world rank of the draft model leader
int32_t draftParticipantId;
/// @brief Returns the struct serialized into a tensor that can be used as generation logits input
[[nodiscard]] Tensor toTensor() const;
};
/// @brief A named output tensor.
struct AdditionalOutput
{
    /// @brief The name of the output
    std::string name;
    /// @brief The tensor holding the output
    Tensor output;

    /// @brief Build an output from its name and tensor. Both arguments are taken by value and moved into the
    /// members.
    AdditionalOutput(std::string name, Tensor output)
        : name(std::move(name))
        , output(std::move(output))
    {
    }

    ~AdditionalOutput() = default;

    // Copy operations.
    AdditionalOutput(AdditionalOutput const& other) = default;
    AdditionalOutput& operator=(AdditionalOutput const& other) = default;

    // Move operations.
    AdditionalOutput(AdditionalOutput&& other) noexcept = default;
    AdditionalOutput& operator=(AdditionalOutput&& other) noexcept = default;
};
/// @brief Struct that holds the generation result
/// @details Every scalar member carries a default member initializer, so a default-initialized Result never holds
/// indeterminate values. The struct remains an aggregate.
struct Result
{
    /// @brief Indicates if this is the final result for the request
    bool isFinal{false};
    /// @brief The output tokens for each beam
    BeamTokens outputTokenIds;
    /// @brief The cumulative log probabilities. Size beamSize.
    std::optional<VecLogProbs> cumLogProbs;
    /// @brief The log probabilities for each generated token. Size [beamSize, outputLen]
    std::optional<std::vector<VecLogProbs>> logProbs;
    /// @brief The context logits. Size [promptLen, vocabSizePadded]
    std::optional<Tensor> contextLogits;
    /// @brief The generation logits. Size [beamSize, maxTokens, vocabSizePadded] (non-streaming)
    /// or [maxTokens, beamSize, vocabSizePadded] (streaming and allGeneratedTokens)
    /// or [1, beamSize, vocabSizePadded] (streaming and non-allGeneratedTokens)
    std::optional<Tensor> generationLogits;
    /// @brief Logits information for direct transfer when using fast logits
    std::optional<SpeculativeDecodingFastLogitsInfo> specDecFastLogitsInfo;
    /// @brief The encoder output. Size [encoderLen, hiddenSize]
    std::optional<Tensor> encoderOutput;
    /// @brief The reason why the model stopped generating tokens for each beam in this request. Size [beamSize].
    /// Currently only supported when beamSize is 1 and when using BatchingType::kINFLIGHT.
    std::vector<FinishReason> finishReasons;
    /// @brief The params of the context phase.
    std::optional<ContextPhaseParams> contextPhaseParams;
    /// @brief The number of the decoding iterations used to generate the result.
    /// In autoregressive decoding, it is equal to the maximum length of the beam in outputTokenIds.
    /// In speculative decoding, might be less than maximum length of the beam in outputTokenIds as more than
    /// one token can be generated per iteration. Used for speculative decoding statistics.
    SizeType32 decodingIter{0};
    /// @brief The average number of decoded tokens per iteration. For standard model it is 1.
    /// For speculative decoding model >= 1 -- number of draft tokens accepted per step + 1.
    float avgDecodedTokensPerIter{0.0f};
    /// @brief The index of the output sequence of this result where 0 <= sequenceIndex < numReturnSequences.
    /// In beam search (beamWidth > 1), this index will be always zero because all beams to be returned are included
    /// in this result.
    SizeType32 sequenceIndex{0};
    /// @brief Indicates if this is the final result for a given sequence in the request
    /// In beam search (beamWidth > 1), the value will always equal to the value of isFinal.
    bool isSequenceFinal{false};
    /// @brief Performance metrics if returnPerfMetrics is set in OutputConfig
    std::optional<RequestPerfMetrics> requestPerfMetrics;
    /// @brief The additional outputs
    std::vector<AdditionalOutput> additionalOutputs;
};
/// @brief Class that holds either an error or a result
class Response
{
public:
    /// @brief Construct a response that carries an error
    /// @param requestId The id of the request this response belongs to
    /// @param errorMsg The error message
    /// @param clientId The optional client id of the request
    Response(IdType requestId, std::string errorMsg, std::optional<IdType> clientId = std::nullopt);
    /// @brief Construct a response that carries a result
    /// @param requestId The id of the request this response belongs to
    /// @param result The generation result
    /// @param clientId The optional client id of the request
    Response(IdType requestId, Result result, std::optional<IdType> clientId = std::nullopt);

    ~Response();
    Response(Response const& other);
    Response(Response&& other) noexcept;
    Response& operator=(Response const& other);
    Response& operator=(Response&& other) noexcept;

    /// @brief Get the id of the request for which this response was generated
    [[nodiscard]] IdType getRequestId() const;
    /// @brief Get the client id of the request for which this response was generated
    [[nodiscard]] std::optional<IdType> getClientId() const;
    /// @brief Indicates if this response has an error or not
    [[nodiscard]] bool hasError() const;
    /// @brief Get the error msg for this response
    /// Will throw an exception if hasError is false
    [[nodiscard]] std::string const& getErrorMsg() const;
    /// @brief Get the result for this response
    /// Will throw an exception if hasError is true
    [[nodiscard]] Result const& getResult() const;

private:
    // Serialization is a friend and can therefore reach the private implementation pointer.
    friend class Serialization;
    // Impl is only forward-declared in this class; mImpl is the sole data member.
    class Impl;
    std::unique_ptr<Impl> mImpl;
};
/// @brief Configuration class for dynamic tuning of batch size and max num tokens. During runtime the statistics of
/// input and output lengths are recorded. Based on these statistics, the batch size and max num tokens are tuned
/// dynamically to better serve the requests.
class DynamicBatchConfig
{
public:
/// @brief The default window size for moving average of input and output length which is used to calculate dynamic
/// batch size and max num tokens
static SizeType32 const kDefaultDynamicBatchMovingAverageWindow = 128;
explicit DynamicBatchConfig(bool enableBatchSizeTuning = false, bool enableMaxNumTokensTuning = false,
SizeType32 dynamicBatchMovingAverageWindow = kDefaultDynamicBatchMovingAverageWindow,
std::vector<std::pair<SizeType32, SizeType32>> batchSizeTable = kDefaultBatchSizeTable);
[[nodiscard]] SizeType32 getDynamicBatchMovingAverageWindow() const;
[[nodiscard]] bool getEnableBatchSizeTuning() const;
[[nodiscard]] bool getEnableMaxNumTokensTuning() const;
[[nodiscard]] std::vector<std::pair<SizeType32, SizeType32>> getBatchSizeTable() const;
/// @brief The default value of batch size table
static std::vector<std::pair<SizeType32, SizeType32>> const kDefaultBatchSizeTable;
private:
friend class Serialization;
/// @brief Controls if the batch size should be tuned dynamically
bool mEnableBatchSizeTuning;