@@ -113,17 +113,51 @@ struct common_sampler {
     llama_token_data_array cur_p;
 
     void set_logits(struct llama_context * ctx, int idx) {
-        const auto * logits = llama_get_logits_ith(ctx, idx);
+        const float       * sampled_probs  = llama_get_sampled_probs_ith(ctx, idx);
+        const float       * sampled_logits = llama_get_sampled_logits_ith(ctx, idx);
+        const llama_token * sampled_ids    = llama_get_sampled_token_ids_ith(ctx, idx);
 
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
 
         const int n_vocab = llama_vocab_n_tokens(vocab);
 
-        cur.resize(n_vocab);
+        // Use the member variable instead of allocating locally
+        cur.clear();
 
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        if (sampled_probs) {
+            const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
+            cur.reserve(sampled_probs_count);
+            // The GPU sampler has filtered the probabilities, so we need to use the sampled ids.
+            if (sampled_ids != nullptr) {
+                for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+                    cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
+                }
+            } else {
+                for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
+                    cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
+                }
+            }
+        } else if (sampled_logits) {
+            const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
+            cur.reserve(sampled_logits_count);
+            // The GPU sampler has filtered the logits, so we need to use the sampled ids.
+            if (sampled_ids != nullptr) {
+                for (llama_token i = 0; i < (int) sampled_logits_count; i++) {
+                    cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
+                }
+            } else {
+                for (llama_token token_id = 0; token_id < (int) sampled_logits_count; token_id++) {
+                    cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
+                }
+            }
+        } else {
+            const auto * logits = llama_get_logits_ith(ctx, idx);
+            GGML_ASSERT(logits != nullptr);
+            cur.reserve(n_vocab);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            }
         }
 
         cur_p = { cur.data(), cur.size(), -1, false };
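Below is a minimal, self-contained sketch (not part of the diff) of the candidate-building precedence the new set_logits() implements: GPU-sampled probabilities take priority, then GPU-sampled logits, then the full-vocab CPU logits path; when the GPU sampler has filtered the candidate set, sampled_ids maps each slot back to a real token id. llama_token and llama_token_data are redefined here only so the sketch compiles without llama.h.

// sketch.cpp -- illustration only, not PR code
#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t llama_token;

struct llama_token_data {
    llama_token id;    // token id
    float       logit; // log-odds (left at 0.0f when probabilities are provided)
    float       p;     // probability (left at 0.0f when logits are provided)
};

// Exactly one of probs/logits is non-null; ids may be null when the
// candidate set is the untouched, in-order vocabulary.
static std::vector<llama_token_data> build_candidates(const float * probs, const float * logits,
                                                      const llama_token * ids, uint32_t count) {
    std::vector<llama_token_data> cur;
    cur.reserve(count);
    for (uint32_t i = 0; i < count; ++i) {
        const llama_token id = ids ? ids[i] : (llama_token) i;
        if (probs) {
            cur.emplace_back(llama_token_data{id, 0.0f, probs[i]});  // probability path
        } else {
            cur.emplace_back(llama_token_data{id, logits[i], 0.0f}); // logit path
        }
    }
    return cur;
}

int main() {
    // e.g. a GPU top-k=3 result: filtered probabilities plus their token ids
    const float       probs[] = {0.7f, 0.2f, 0.1f};
    const llama_token ids[]   = {42, 7, 13};
    for (const auto & td : build_candidates(probs, nullptr, ids, 3)) {
        printf("id=%d logit=%.2f p=%.2f\n", td.id, td.logit, td.p);
    }
    return 0;
}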
@@ -287,6 +321,42 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     return result;
 }
 
+struct llama_sampler * common_sampler_gpu_init(const struct llama_model * model, const struct common_params_sampling & params) {
+    GGML_UNUSED(model);
+
+    llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+    chain_params.no_perf = params.no_perf;
+
+    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+
+    if (!params.gpu_sampling) {
+        return chain; // return empty chain
+    }
+
+    if (params.gpu_temp > 0.0f) {
+        llama_sampler_chain_add(chain, llama_sampler_gpu_init_temp(params.gpu_temp));
+    }
+
+    if (params.gpu_top_k > 0) {
+        llama_sampler_chain_add(chain, llama_sampler_gpu_init_top_k(params.gpu_top_k));
+    }
+
+    // TODO: GPU top_p is an approximation using top_k at the moment
+    if (params.gpu_top_p_approx_k > 0) {
+        llama_sampler_chain_add(chain, llama_sampler_gpu_init_top_p(params.gpu_top_p_approx_k));
+    }
+
+    if (params.gpu_softmax) {
+        llama_sampler_chain_add(chain, llama_sampler_gpu_init_softmax());
+    }
+
+    if (params.gpu_dist) {
+        llama_sampler_chain_add(chain, llama_sampler_gpu_init_dist(params.seed));
+    }
+
+    return chain;
+}
+
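A hedged usage sketch (an assumed call pattern, not taken from the PR): configure the gpu_* fields of common_params_sampling and build the chain; samplers are added in the order temperature, top-k, approximate top-p, softmax, dist, mirroring the function above. How the returned chain is attached to the llama_context is defined elsewhere in the PR, so that step is left as a comment.

common_params_sampling params;
params.gpu_sampling       = true;  // master switch; otherwise the chain stays empty
params.gpu_temp           = 0.8f;
params.gpu_top_k          = 40;
params.gpu_top_p_approx_k = 0;     // top_p is currently approximated via top_k; disabled here
params.gpu_softmax        = true;  // produce probabilities on the GPU
params.gpu_dist           = true;  // sample the final token on the GPU as well
params.seed               = 1234;

struct llama_sampler * gpu_chain = common_sampler_gpu_init(model, params);

// ... attach gpu_chain to the context using the mechanism this PR introduces ...

llama_sampler_free(gpu_chain);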
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
         llama_sampler_free(gsmpl->grmr);
@@ -337,6 +407,13 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 }
 
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+    // Check if a GPU sampler has already sampled a token, in which case we
+    // return that token id directly.
+    const llama_token gpu_sampled_token = llama_get_sampled_token_ith(ctx, idx);
+    if (gpu_sampled_token != LLAMA_TOKEN_NULL) {
+        return gpu_sampled_token;
+    }
+
     gsmpl->set_logits(ctx, idx);
 
     auto & grmr = gsmpl->grmr;
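A minimal sketch (assumed decode-loop usage, not from the PR) of what the early return means for callers: when the GPU chain has already produced a token for this batch position, common_sampler_sample() returns it immediately, so no CPU-side candidate array is built and no CPU sampler, including the grammar sampler, runs for that step.

// One step of a decode loop. With GPU sampling active,
// llama_get_sampled_token_ith(ctx, idx) != LLAMA_TOKEN_NULL and `tok` below is
// the GPU-sampled id; otherwise the usual CPU sampling path runs.
const llama_token tok = common_sampler_sample(gsmpl, ctx, /*idx =*/ -1, /*grammar_first =*/ false);

common_sampler_accept(gsmpl, tok, /*accept_grammar =*/ true);

if (llama_vocab_is_eog(vocab, tok)) {
    // end of generation
}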