llama: implement YaRN RoPE scaling #2268

Merged
cebtenzzre merged 36 commits into ggml-org:master from cebtenzzre:ntkv2
Nov 1, 2023

Conversation

@cebtenzzre

@cebtenzzre cebtenzzre commented Jul 18, 2023

Collaborator

This is an implementation of YaRN RoPE scaling. See https://github.com/jquesnelle/yarn and the paper and errata.

TODO:

  • Add new GGUF key for how much context the base model was trained on
  • Support converting the new models to GGUF
  • Add backward implementations
  • Test new LLaMA implementation
  • Finish and test Falcon implementation

@cebtenzzre cebtenzzre force-pushed the ntkv2 branch 3 times, most recently from ce59171 to f3b9eae Compare July 19, 2023 03:55
@cebtenzzre cebtenzzre changed the title llama: implement NTK-By-Parts (NTKv2) llama: implement NTK-By-Parts (NTKv2) RoPE scaling Jul 19, 2023
@FNsi

FNsi commented Jul 20, 2023

Contributor

Is there a guide for setting the extrapolation and NTK parameters? How do they interact with the previous two parameters?

@cebtenzzre

Collaborator Author

The upstream NTKv2 doesn't use --rope-freq-base, so it probably doesn't make sense to use it here. It does use --rope-freq-scale, which works like linear scaling and is supposed to be calibrated so that, e.g., a scale of 0.25 actually gives you 8192 context. To use the default NTKv2, set --rope-ntk-factor and --rope-extrapolation-factor to 1, and set --rope-freq-scale appropriately. The lower the factors are, the less the respective scaling methods are mixed in, although I believe the graphs were generated with both at 100%. The code automatically ramps them based on some experimentally determined thresholds.
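
As a rough illustration of the mixing described above, the following C sketch blends a linearly interpolated rotation angle with the unscaled one, using a ramp over the dimension-pair index. It is a sketch under assumptions, not the code from this PR: the function names, the linear shape of the ramp, and the `low`/`high` thresholds are hypothetical, and only the extrapolation factor is modelled.

```c
#include <assert.h>

/* Hypothetical linear ramp over the dimension-pair index:
 * 1 at or below `low`, 0 at or above `high`, linear in between. */
static float ramp(float low, float high, int pair_index) {
    assert(high > low);
    float y = ((float)pair_index - low) / (high - low);
    if (y < 0.0f) y = 0.0f;
    if (y > 1.0f) y = 1.0f;
    return 1.0f - y;
}

/* Blend the interpolated angle (linear scaling by freq_scale) with the
 * unscaled angle. ext_factor = 0 gives pure linear scaling; ext_factor = 1
 * applies the ramp at full strength. */
static float mixed_theta(float theta_extrap, float freq_scale,
                         float low, float high, int pair_index,
                         float ext_factor) {
    assert(freq_scale > 0.0f);
    const float theta_interp = freq_scale * theta_extrap;
    const float mix = ramp(low, high, pair_index) * ext_factor;
    return theta_interp * (1.0f - mix) + theta_extrap * mix;
}
```

With a factor of 0 every dimension pair gets the linearly scaled angle; with a factor of 1, pairs below `low` keep the unscaled angle and pairs above `high` are fully interpolated.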

@cebtenzzre cebtenzzre marked this pull request as ready for review July 21, 2023 22:04
@cebtenzzre

Collaborator Author

I would appreciate help with the following:

  • Should I try to write a backwards implementation? NTKv1 still doesn't have one, so I don't have much to base it on.
  • I don't have a Mac to test the Metal code on. If anyone sees obvious flaws or can test it locally, let me know.
  • I'm going to try to run a perplexity benchmark against NTKv1 and linear scaling, but I don't know if my current hardware is up to the task.

@ggerganov ggerganov left a comment

Member

Rename extrapolation_factor to ext_factor everywhere.

Comment thread ggml.c Outdated
@ggerganov

Member

No need for backwards implementation for now

@cebtenzzre

Collaborator Author

Perplexity with NTKv2 may be worse because neither is the dynamic version, which AFAIK works better on non-finetuned models. But fine-tuned models are far superior anyway.

NTKv1 does not converge when fine-tuning, which is why NTKv2 exists. So until somebody publishes a model fine-tuned with NTKv2—maybe LLongMAv2 will be released after jquesnelle publishes the paper based on scaled-rope—the existing LLongMA, which uses regular linear interpolation (just like SuperHOT), is the state-of-the-art for long contexts.

@cebtenzzre

cebtenzzre commented Aug 31, 2023

Collaborator Author

The paper has been released. The resulting method is called YaRN. Apparently the models that use this technique are good to about 120k tokens of context.
[Image: screenshot from 2023-08-31 16-53-18]

More work will definitely be needed to use these models with llama.cpp.

@cebtenzzre cebtenzzre changed the title llama: implement NTK-By-Parts (NTKv2) RoPE scaling llama: implement YaRN RoPE scaling Sep 5, 2023
@bloc97

bloc97 commented Sep 6, 2023

Thank you for the llamacpp implementation of YaRN!

I'm just letting you know that

constant float max_pos_emb = 2048;

should be changed to 4096 for Llama 2 models when using YaRN (the default was 2048 because we ran most of our tests with LLaMA 1 models).
This value should probably be saved in the model config and loaded at inference time.

@cebtenzzre

Collaborator Author

should be changed to 4096 for llama 2 models

Thanks for reminding me. I originally made this PR before GGUF was finished, so I hardcoded it in the meantime. I believe I can now use the value of llama.context_length for this purpose.

@KerfuffleV2

Contributor

Would it be worth testing this with non-YaRN fine-tuned models? If so, any suggested settings? I can test it with ROCM.

@Green-Sky

Green-Sky commented Sep 6, 2023

Collaborator

Thank you for the llamacpp implementation of YaRN!

I'm just letting you know that

constant float max_pos_emb = 2048;

should be changed to 4096 for llama 2 models when using YaRN (default was 2048 because we did the most tests with llama 1 models) This value should probably be saved inside of the model configs and be loaded on inference...

this needs to be a new GGUF kv, something like "rope_yarn_orig_ctx"

Thanks for reminding me. I originally made this PR before GGUF was finished, so I hardcoded it in the meantime. I believe I can now use the value of llama.context_length for this purpose.

llama.context_length should be the context size of the fine-tune, e.g. 128Ki.

@cebtenzzre cebtenzzre marked this pull request as draft September 6, 2023 15:50
@bloc97

bloc97 commented Sep 6, 2023

this needs to be a new GGUF kv, something like "rope_yarn_orig_ctx"

Exactly. After fine-tuning a model with YaRN, we have to keep track of two values: the original context length (2048 for LLaMA or 4096 for Llama 2) and the final context length (which can be calculated by multiplying the original context length by the scale factor, e.g. 4096 x 32 = 128Ki).

In this case, the constant `constant float max_pos_emb = 2048;` used in the equations must be equal to the original context size, not the final context size.

Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026
Co-authored-by: cebtenzzre <[email protected]>
Co-authored-by: Jeffrey Quesnelle <[email protected]>
* fix backward process of rope

rope backward process was broken after YaRN RoPE (ggml-org#2268) implementation, due to missing changes in backward functions.

the code for the backward process is nearly identical to the forward process:
the only difference is the sign of the sin-values.

to avoid future regressions remove the near-duplicate backward functions and reuse the forward code:

for this a new function argument `bool forward` was added to `ggml_compute_forward_rope_f32` and `ggml_compute_forward_rope_f16`.
the sin-values will be negated when forward is false.

* fix finetune rope call to use correct default attn_factor of 1.0f

* remove unused `ggml_rope_xpos_back`

it is better to have only one `ggml_rope_back` function that accepts all rope parameters, so that `ggml_compute_backward` can propagate all parameters without having to switch between different rope_back variants.

* fix comments explaining the sine sign in ggml_forward_rope

* add missing function arguments in declaration

* fix function argument type in declaration