
Conversation

@yuanyao-nv (Contributor) commented Jul 23, 2025

Description

To accompany the TensorScatter-24 op for managing in-place KV cache update, this PR makes the following changes to the Attention op:

  • Add nonpad_kv_seqlen to indicate the number of valid (non-padded) tokens in the K and V inputs when K and V are the entire cache tensors (where the valid tokens can make up only a small proportion of the cache). The nonpad_kv_seqlen input gives backends the opportunity to skip unnecessary computation on the padding tokens.
  • Allow the kv_seqlen dimension (the -1 dimension) of the attn_mask input to be shorter than K and V; the missing portion is assumed to be -inf. The mask length should still be larger than the maximum value in nonpad_kv_seqlen.

Also, allow attn_mask and is_causal to be present at the same time, which allows for easier export of HF models later.
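For intuition, here is a minimal NumPy sketch of the two mask rules above. It is a non-normative illustration; kv_padding_bias and extend_attn_mask are hypothetical helper names, not part of the spec:

    import numpy as np

    def kv_padding_bias(nonpad_kv_seqlen, total_kv_len):
        # Additive bias over cache positions: 0 where the token is valid,
        # -inf where it is cache padding. Result shape: [batch, total_kv_len].
        pos = np.arange(total_kv_len)
        valid = pos[None, :] < nonpad_kv_seqlen[:, None]
        return np.where(valid, 0.0, -np.inf).astype(np.float32)

    def extend_attn_mask(attn_mask, total_kv_len):
        # Pad a mask whose trailing (kv_seqlen) dimension is shorter than
        # the cache with -inf, matching the relaxed attn_mask rule.
        pad = total_kv_len - attn_mask.shape[-1]
        widths = [(0, 0)] * (attn_mask.ndim - 1) + [(0, pad)]
        return np.pad(attn_mask, widths, constant_values=-np.inf)

    bias = kv_padding_bias(np.array([3, 4], dtype=np.int64), 6)     # shape [2, 6]
    mask = extend_attn_mask(np.zeros((2, 3, 4, 4), np.float32), 6)  # shape [2, 3, 4, 6]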

Motivation and Context

@yuanyao-nv yuanyao-nv requested a review from a team as a code owner July 23, 2025 23:34
@github-project-automation github-project-automation bot moved this to In progress in PR Tracker Jul 23, 2025
@yuanyao-nv yuanyao-nv marked this pull request as draft July 23, 2025 23:35
codecov bot commented Jul 23, 2025

Codecov Report

❌ Patch coverage is 54.54545% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 53.76%. Comparing base (ee724f6) to head (f92c9d9).
⚠️ Report is 77 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines:
  onnx/backend/test/case/node/attention.py: patch 0.00%, 10 lines missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #7164   +/-   ##
=======================================
  Coverage   53.76%   53.76%           
=======================================
  Files         512      512           
  Lines       32180    32202   +22     
  Branches     2942     2945    +3     
=======================================
+ Hits        17300    17312   +12     
- Misses      14110    14120   +10     
  Partials      770      770           


Comment on lines +4251 to +4252
ONNX_OPERATOR_SET_SCHEMA(
Attention,

Check notice (Code scanning / CodeQL): Unused static variable
Static variable dbg_count_check_Onnx_23_verAttention is never read.
@yuanyao-nv yuanyao-nv force-pushed the dev-attention-seqlen branch from 9c714fb to 326aefb Compare July 25, 2025 00:16
@justinchuby justinchuby added this to the 1.19 milestone Jul 25, 2025
Comment on lines +25 to +31
ONNX_ASSERTM(
false,
"%s being converted from %d to %d has nonpad_kv_seqlen input, "
"which is not supported in opset 23. This conversion cannot be performed.",
name().c_str(),
initial_version().version(),
target_version().version());

Check notice (Code scanning / CodeQL): Too many arguments to formatting function
Format for barf (in a macro expansion) expects 7 arguments but given 8.
@gramalingam (Contributor) left a comment:
LGTM, thanks ... just a couple of minor comments left about documentation of attn_sequence_length

@github-project-automation github-project-automation bot moved this from In progress to Reviewer approved in PR Tracker Jul 28, 2025
@yuanyao-nv yuanyao-nv marked this pull request as ready for review July 29, 2025 20:25
@yuanyao-nv yuanyao-nv requested a review from a team as a code owner July 29, 2025 20:25
@justinchuby justinchuby requested a review from Copilot July 29, 2025 21:45
Copilot AI (Contributor) left a comment:

Pull Request Overview

This PR adds a new nonpad_kv_seqlen input to the Attention operator in version 24 to support optimized KV cache management. This enhancement accompanies the TensorScatter-24 operator for managing in-place KV cache updates.

Key changes include:

  • Addition of nonpad_kv_seqlen input to indicate valid (non-padded) tokens in K and V inputs
  • Support for shorter attn_mask dimensions that get padded with -inf
  • Compatibility between attn_mask and is_causal attributes

Reviewed Changes

Copilot reviewed 10 out of 146 changed files in this pull request and generated 1 comment.

Summary per file:
  onnx/version_converter/convert.h: Registers adapters for converting between Attention v24 and v23
  onnx/version_converter/adapters/Attention_24_23.h: Implements the downgrade adapter that prevents conversion when nonpad_kv_seqlen is present
  onnx/reference/ops/op_attention.py: Updates the reference implementation to handle nonpad_kv_seqlen and shorter attn_mask
  onnx/defs/operator_sets.h: Adds Attention-24 to the operator set schema declarations
  onnx/defs/nn/old.cc: Moves the Attention-23 schema to old.cc for version history
  onnx/defs/nn/defs.cc: Implements Attention-24 with updated documentation and function builder
  onnx/backend/test/case/node/attention.py: Adds a test case for the new nonpad_kv_seqlen functionality
  docs/TestCoverage.md: Updates test coverage documentation
  docs/Operators.md: Updates operator documentation for Attention-24
  docs/Changelog.md: Adds the changelog entry for Attention-24
Comments suppressed due to low confidence (2)

onnx/backend/test/case/node/attention.py:1859

  • The test uses a fixed nonpad_kv_seqlen array with values [3, 4], but the K and V tensors have sequence length 6. Consider adding test cases that cover edge cases like when nonpad_kv_seqlen equals the full sequence length, or when it's 0 or 1.
        nonpad_kv_seqlen = np.array([3, 4], dtype=np.int64)

onnx/backend/test/case/node/attention.py:1858

  • The test creates an attention mask with kv_sequence_length=4, but K and V have sequence length 6. This tests the padding functionality, but consider adding a test comment explaining this intentional dimension mismatch to clarify the test's purpose.
        attn_mask = np.random.rand(2, 3, 4, 4).astype(np.float32)
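Taken together, the two quoted lines suggest a setup along these lines (a reconstruction for readability, not the test file verbatim; head_dim = 8 is an assumed value):

    import numpy as np

    batch, heads, q_len, kv_len, head_dim = 2, 3, 4, 6, 8  # head_dim is assumed
    q = np.random.rand(batch, heads, q_len, head_dim).astype(np.float32)
    k = np.random.rand(batch, heads, kv_len, head_dim).astype(np.float32)
    v = np.random.rand(batch, heads, kv_len, head_dim).astype(np.float32)
    # The mask covers only 4 of the 6 KV positions; the spec treats the
    # missing tail as -inf (the intentional dimension mismatch noted above).
    attn_mask = np.random.rand(batch, heads, q_len, 4).astype(np.float32)
    # Per-batch count of valid (non-padded) cache tokens.
    nonpad_kv_seqlen = np.array([3, 4], dtype=np.int64)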

@yuanyao-nv yuanyao-nv merged commit 13b6330 into onnx:main Jul 30, 2025
37 of 38 checks passed
@github-project-automation github-project-automation bot moved this from Reviewer approved to Done in PR Tracker Jul 30, 2025
alx256 pushed a commit to alx256/onnx that referenced this pull request Aug 1, 2025
MagellaX pushed a commit to MagellaX/onnx that referenced this pull request Aug 9, 2025
}
builder
.Add("KVSeqLenExpanded = Unsqueeze(nonpad_kv_seqlen, One1D)") // [batch_size, 1]
.Add("Range = Range(Zero1D, KVSeqLen, One1D)") // [KVSeqLen,]

Contributor: It's caught by ORT, but the ONNX checker does not complain about this for some reason.

Contributor: Just like RMSNorm: https://github.com/onnx/onnx/pull/7135/files (reference fix)

Contributor Author: @titaiwangms how about this?

                .Const("Zero0D", (int64_t)(0))
                .Const("One0D", (int64_t)(1))
                .Add("KVSeqLen0D = Unsqueeze(KVSeqLen, Zero1D)")
                .Add("Range = Range(Zero0D, KVSeqLen0D, One0D)")

Member: Thanks for catching this. We missed it.

Contributor Author: PR to fix: #7240
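For reference, a small non-normative sketch of Range with the 0-D scalar start/limit/delta inputs its spec requires, built with onnx's Python helper and run through the reference evaluator:

    import numpy as np
    from onnx import TensorProto, helper
    from onnx.reference import ReferenceEvaluator

    # Range's start, limit, and delta inputs are 0-D scalars per the op spec.
    node = helper.make_node("Range", ["start", "limit", "delta"], ["pos"])
    graph = helper.make_graph(
        [node],
        "range_scalar",
        [helper.make_tensor_value_info(n, TensorProto.INT64, [])
         for n in ("start", "limit", "delta")],
        [helper.make_tensor_value_info("pos", TensorProto.INT64, [None])],
    )
    sess = ReferenceEvaluator(helper.make_model(graph))
    feeds = {n: np.array(v, dtype=np.int64) for n, v in
             [("start", 0), ("limit", 6), ("delta", 1)]}
    print(sess.run(None, feeds)[0])  # -> [0 1 2 3 4 5]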

yuanyao-nv added a commit that referenced this pull request Aug 19, 2025
### Description
In the Attention op definition, update the inputs to Range to be
scalars as opposed to 1-element vectors, as required by the Range op
spec.

### Motivation and Context
See the discussion thread above (#7164).

---------

Signed-off-by: Yuan Yao <[email protected]>
yuanyao-nv added a commit to yuanyao-nv/onnx that referenced this pull request Aug 19, 2025
xadupre pushed a commit to xadupre/onnx that referenced this pull request Aug 31, 2025
@justinchuby justinchuby added the topic: operator Issues related to ONNX operators label Sep 15, 2025

Labels

topic: operator Issues related to ONNX operators

Projects

Status: Done


5 participants