
Commit 70bc0b8

Fix a bug in the rope calculation
1 parent 18ebda3 commit 70bc0b8

File tree

4 files changed: +92 -6 lines


convert-pth-to-ggml.py (+1, -1)

@@ -73,7 +73,7 @@
 fout.write(struct.pack("i", hparams["multiple_of"]))
 fout.write(struct.pack("i", hparams["n_heads"]))
 fout.write(struct.pack("i", hparams["n_layers"]))
-fout.write(struct.pack("i", 64)) # rot
+fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
 fout.write(struct.pack("i", ftype))
 
 # Is this correct??
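The value written here is the per-head width (dim / n_heads), which is the number of dimensions the rotary position embedding operates on. A minimal C++ sketch of the same arithmetic, using (dim, n_heads) pairs taken from the published LLaMA configurations (an assumption drawn from the LLaMA paper, not from this commit), shows why the old hardcoded 64 was wrong for every model size:

#include <cstdio>

// Per-head width = dim / n_heads; this is the "rot" value the converter
// now derives instead of hardcoding 64.
int main() {
    // (dim, n_heads) for the four published LLaMA sizes (assumption:
    // values from the LLaMA paper, not read from this repository).
    const int configs[4][2] = {
        {4096, 32},  // 7B
        {5120, 40},  // 13B
        {6656, 52},  // 30B
        {8192, 64},  // 65B
    };
    for (const auto & c : configs) {
        printf("dim=%d n_heads=%d -> rot=%d\n", c[0], c[1], c[0] / c[1]);
    }
    // Every size yields rot=128, so the old constant 64 covered only
    // half of each head's dimensions.
    return 0;
}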

main.cpp (+5, -3)

@@ -400,7 +400,7 @@ bool llama_eval(
     const int n_ctx   = hparams.n_ctx;
     const int n_head  = hparams.n_head;
     const int n_vocab = hparams.n_vocab;
-    const int n_rot   = hparams.n_rot;
+    const int n_rot   = hparams.n_embd/hparams.n_head;
 
     const int d_key = n_embd/n_head;
 
@@ -628,6 +628,9 @@ int main(int argc, char ** argv) {
         params.prompt = gpt_random_prompt(rng);
     }
 
+    // params.prompt = R"(// this function checks if the number n is prime
+    //bool is_prime(int n) {)";
+
     int64_t t_load_us = 0;
 
     gpt_vocab vocab;
@@ -691,7 +694,6 @@ int main(int argc, char ** argv) {
 
         if (i >= embd_inp.size()) {
             // sample next token
-            const int top_k = params.top_k;
             const float top_p = params.top_p;
             const float temp  = params.temp;
 
@@ -702,7 +704,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+                id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
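The corrected n_rot is the number of leading dimensions per head that RoPE rotates in pairs. A standalone sketch of that rotation (base 10000, interleaved pairs; an illustrative generic RoPE implementation, not the ggml kernel this repository actually calls):

#include <cmath>
#include <vector>

// Apply rotary position embedding to one attention head in place.
// x holds head_dim elements; only the first n_rot are rotated, so if
// n_rot < head_dim the remaining dimensions carry no position signal.
void rope(std::vector<float> & x, int n_rot, int pos) {
    for (int i = 0; i < n_rot; i += 2) {
        const float theta = pos * std::pow(10000.0f, -(float)i / n_rot);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = x[i];
        const float x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}

With the previous hardcoded value of 64, the upper half of each 128-wide head never received a position signal, which is the bug this commit fixes.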

utils.cpp (+78, -1)

@@ -257,7 +257,7 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
             }
         }
 
-        if (l == 0 && t != 13) {
+        if (l == 0) {
             break;
         }
 
@@ -367,6 +367,83 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     return logits_id[idx].second;
 }
 
+gpt_vocab::id llama_sample_top_p(
+        const gpt_vocab & vocab,
+        const float * logits,
+        double top_p,
+        double temp,
+        std::mt19937 & rng) {
+    int n_logits = vocab.id_to_token.size();
+
+    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const double scale = 1.0/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+        }
+    }
+
+    std::sort(
+            logits_id.begin(),
+            logits_id.end(),
+            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                probs.resize(i + 1);
+                logits_id.resize(i + 1);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+    //printf("\n");
+    //for (int i = 0; i < (int) 10; i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //}
+    //printf("\n\n");
+    //exit(0);
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+}
+
+
 size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
     const int nb = k / qk;
     const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*qk/2);
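A sketch of driving the new sampler in isolation. The toy vocabulary, logits, and seed are illustrative assumptions; a real run fills vocab and logits from the loaded model:

#include <random>
#include <string>
#include <vector>

// Assumes utils.h from this repository is on the include path.
#include "utils.h"

int main() {
    gpt_vocab vocab;
    // Toy 4-token vocabulary; a real run loads this from the model file.
    for (int i = 0; i < 4; ++i) {
        vocab.id_to_token[i] = "tok" + std::to_string(i);
    }

    // Toy logits strongly favoring token 2.
    std::vector<float> logits = {0.1f, 0.2f, 2.0f, 0.3f};

    std::mt19937 rng(1234);
    gpt_vocab::id id = llama_sample_top_p(vocab, logits.data(), 0.95, 0.80, rng);
    // With top_p = 0.95 and temp = 0.80, id is most often 2 here.
    return 0;
}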

utils.h (+8, -1)

@@ -18,7 +18,7 @@ struct gpt_params {
     int32_t n_predict = 128; // new tokens to predict
 
     // sampling parameters
-    int32_t top_k = 40;
+    int32_t top_k = 40; // unused
     float   top_p = 0.95f;
     float   temp  = 0.80f;
 
@@ -86,6 +86,13 @@ gpt_vocab::id gpt_sample_top_k_top_p(
         double temp,
         std::mt19937 & rng);
 
+gpt_vocab::id llama_sample_top_p(
+        const gpt_vocab & vocab,
+        const float * logits,
+        double top_p,
+        double temp,
+        std::mt19937 & rng);
+
 //
 // Quantization
 //
