Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7837c3f

Browse files
committed
Fix return types and import comments
1 parent 55d6308 commit 7837c3f

File tree

1 file changed

+38
-34
lines changed

1 file changed

+38
-34
lines changed

llama_cpp/llama_cpp.py

+38-34
Original file line numberDiff line numberDiff line change
@@ -427,13 +427,16 @@ def llama_token_nl() -> llama_token:
427427

428428

429429
# Sampling functions
430+
431+
432+
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
430433
def llama_sample_repetition_penalty(
431434
ctx: llama_context_p,
432435
candidates,
433436
last_tokens_data,
434437
last_tokens_size: c_int,
435438
penalty: c_float,
436-
) -> llama_token:
439+
):
437440
return _lib.llama_sample_repetition_penalty(
438441
ctx, candidates, last_tokens_data, last_tokens_size, penalty
439442
)
@@ -446,18 +449,18 @@ def llama_sample_repetition_penalty(
446449
c_int,
447450
c_float,
448451
]
449-
_lib.llama_sample_repetition_penalty.restype = llama_token
452+
_lib.llama_sample_repetition_penalty.restype = None
450453

451454

452-
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
455+
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
453456
def llama_sample_frequency_and_presence_penalties(
454457
ctx: llama_context_p,
455458
candidates,
456459
last_tokens_data,
457460
last_tokens_size: c_int,
458461
alpha_frequency: c_float,
459462
alpha_presence: c_float,
460-
) -> llama_token:
463+
):
461464
return _lib.llama_sample_frequency_and_presence_penalties(
462465
ctx,
463466
candidates,
@@ -476,25 +479,23 @@ def llama_sample_frequency_and_presence_penalties(
476479
c_float,
477480
c_float,
478481
]
479-
_lib.llama_sample_frequency_and_presence_penalties.restype = llama_token
482+
_lib.llama_sample_frequency_and_presence_penalties.restype = None
480483

481484

482-
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
483-
def llama_sample_softmax(ctx: llama_context_p, candidates) -> llama_token:
485+
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
486+
def llama_sample_softmax(ctx: llama_context_p, candidates):
484487
return _lib.llama_sample_softmax(ctx, candidates)
485488

486489

487490
_lib.llama_sample_softmax.argtypes = [
488491
llama_context_p,
489492
llama_token_data_array_p,
490493
]
491-
_lib.llama_sample_softmax.restype = llama_token
494+
_lib.llama_sample_softmax.restype = None
492495

493496

494-
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
495-
def llama_sample_top_k(
496-
ctx: llama_context_p, candidates, k: c_int, min_keep: c_int
497-
) -> llama_token:
497+
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
498+
def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
498499
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
499500

500501

@@ -504,12 +505,11 @@ def llama_sample_top_k(
504505
c_int,
505506
c_int,
506507
]
508+
_lib.llama_sample_top_k.restype = None
507509

508510

509-
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
510-
def llama_sample_top_p(
511-
ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
512-
) -> llama_token:
511+
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
512+
def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
513513
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
514514

515515

@@ -519,13 +519,13 @@ def llama_sample_top_p(
519519
c_float,
520520
c_int,
521521
]
522-
_lib.llama_sample_top_p.restype = llama_token
522+
_lib.llama_sample_top_p.restype = None
523523

524524

525-
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
525+
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
526526
def llama_sample_tail_free(
527527
ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
528-
) -> llama_token:
528+
):
529529
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
530530

531531

@@ -535,13 +535,11 @@ def llama_sample_tail_free(
535535
c_float,
536536
c_int,
537537
]
538-
_lib.llama_sample_tail_free.restype = llama_token
538+
_lib.llama_sample_tail_free.restype = None
539539

540540

541-
# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
542-
def llama_sample_typical(
543-
ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
544-
) -> llama_token:
541+
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
542+
def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
545543
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
546544

547545

@@ -551,13 +549,10 @@ def llama_sample_typical(
551549
c_float,
552550
c_int,
553551
]
554-
_lib.llama_sample_typical.restype = llama_token
552+
_lib.llama_sample_typical.restype = None
555553

556554

557-
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
558-
def llama_sample_temperature(
559-
ctx: llama_context_p, candidates, temp: c_float
560-
) -> llama_token:
555+
def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
561556
return _lib.llama_sample_temperature(ctx, candidates, temp)
562557

563558

@@ -566,10 +561,15 @@ def llama_sample_temperature(
566561
llama_token_data_array_p,
567562
c_float,
568563
]
569-
_lib.llama_sample_temperature.restype = llama_token
564+
_lib.llama_sample_temperature.restype = None
570565

571566

572-
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
567+
# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
568+
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
569+
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
570+
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
571+
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
572+
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
573573
def llama_sample_token_mirostat(
574574
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
575575
) -> llama_token:
@@ -587,7 +587,11 @@ def llama_sample_token_mirostat(
587587
_lib.llama_sample_token_mirostat.restype = llama_token
588588

589589

590-
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
590+
# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
591+
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
592+
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
593+
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
594+
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
591595
def llama_sample_token_mirostat_v2(
592596
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
593597
) -> llama_token:
@@ -604,7 +608,7 @@ def llama_sample_token_mirostat_v2(
604608
_lib.llama_sample_token_mirostat_v2.restype = llama_token
605609

606610

607-
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
611+
# @details Selects the token with the highest probability.
608612
def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
609613
return _lib.llama_sample_token_greedy(ctx, candidates)
610614

@@ -616,7 +620,7 @@ def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
616620
_lib.llama_sample_token_greedy.restype = llama_token
617621

618622

619-
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
623+
# @details Randomly selects a token from the candidates based on their probabilities.
620624
def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
621625
return _lib.llama_sample_token(ctx, candidates)
622626

0 commit comments

Comments
 (0)