Conversation

@lgai-exaone
Contributor

What does this PR do?

This PR adds the modeling code for EXAONE 4.0 ahead of the official model release by LG AI Research.
Model weights, test code, and documentation will be updated once the official release is available.

This contribution is licensed under the MIT License. Please refer to the attached NOTICE.md for the license notice.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker

Collaborator

@ArthurZucker left a comment

Thanks a lot for using Modular!
Looks good, let's make sure we don't have unnecessary code paths!

Comment on lines 261 to 304
@torch.jit.script
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

Collaborator

happy to also import from llama as they are similar!
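For reference, a minimal sketch of what that could look like if these helpers stay identical to the Llama versions (the module path below is the usual transformers layout, assumed here rather than taken from this PR):

# Sketch only: reuse the Llama helpers instead of redefining them in the EXAONE 4.0 file,
# assuming they remain byte-for-byte identical to the Llama implementations.
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    repeat_kv,
    rotate_half,
)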

Comment on lines 463 to 468
        if self.reorder_qk_norm:
            self.post_attention_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.post_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.input_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.pre_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Collaborator

Does the new model have this set to True or False?

Contributor Author

In the upcoming release, the model will use only reorder_qk_norm=True, so we're removing that config option and its related code paths!
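For illustration, a minimal sketch of the single remaining layer-norm path under that assumption (class and attribute names follow the snippet quoted above; the exact forward wiring is an assumption, not the final modeling code):

import torch.nn as nn


class Exaone4DecoderLayerSketch(nn.Module):
    # Sketch only: keeps just the reorder_qk_norm=True layout, i.e. RMSNorm applied to the
    # sub-layer outputs before the residual additions. Exaone4Attention, Exaone4MLP and
    # Exaone4RMSNorm are assumed to be the classes defined elsewhere in the modeling file.
    def __init__(self, config):
        super().__init__()
        self.self_attn = Exaone4Attention(config)
        self.mlp = Exaone4MLP(config)
        self.post_attention_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, **kwargs):
        residual = hidden_states
        hidden_states, _ = self.self_attn(hidden_states, **kwargs)
        hidden_states = residual + self.post_attention_layernorm(hidden_states)

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.post_feedforward_layernorm(hidden_states)
        return hidden_states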

    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        residual = hidden_states

        # We use one of LN options:
Collaborator

Same here, we avoid code paths as much as possible!

@lgai-exaone force-pushed the add-exaone4 branch 2 times, most recently from 4d60c1a to 614dd4d on July 2, 2025 06:32
@lgai-exaone
Contributor Author

We found an issue with the implementation of KV cache slicing when using FA2 & SWA, so we've updated the relevant code.
Please check this update and let us know if there are any other issues!

@lgai-exaone
Contributor Author

Hello, @ArthurZucker.
After several tests, we found that our recent commit did not fully resolve the issue and actually introduced another one.

During our investigation, we noticed that HybridCache.get_seq_length() uses only the 0-th layer, which can be problematic if the 0-th layer is 'sliding_attention', as it returns sliding_window_size instead of the actual seq_len.

Although we replaced past_key_value.get_seq_length() with cache_position[-1] + 1, this approach is still not optimized for CUDA Graph due to dynamic shaping.

We’d like to confirm whether this method is implemented as intended, or if there are plans to update it in the future. If an update is planned, the modeling code may also need to be updated.
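For context, a small self-contained sketch of the workaround described above (names are illustrative, not the exact modeling code):

import torch


def slice_cached_kv(key_states: torch.Tensor, value_states: torch.Tensor, cache_position: torch.Tensor):
    # HybridCache.get_seq_length() only looks at layer 0, so when layer 0 is a sliding-attention
    # layer it can report the window size instead of the true sequence length. Deriving the
    # length from cache_position avoids that, at the cost of a dynamic shape that is not ideal
    # for CUDA Graph capture.
    seq_len = cache_position[-1] + 1
    return key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]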

@lgai-exaone force-pushed the add-exaone4 branch 5 times, most recently from 363094c to 4db07b0 on July 8, 2025 01:11
@lgai-exaone
Contributor Author

We have updated the modeling code with the recent commits, and it seems to be working well, except for a few failing cases caused by unrelated issues.
@ArthurZucker, could you please review the code for merging?

@lgai-exaone
Contributor Author

We are happy to announce that our EXAONE 4.0 models are released!
https://github.com/LG-AI-EXAONE/EXAONE-4.0

@ArthurZucker
Collaborator

Sorry for my delayed review, on it! 🤗

Collaborator

@ArthurZucker left a comment

Marvelous! Thanks a lot for using modular! 🤗

Collaborator

let's remove this!

Comment on lines 346 to 347
"sin": sin,
"cos": cos,
Collaborator

Suggested change
"sin": sin,
"cos": cos,

Contributor Author

Is it necessary to remove sin/cos values from the cache in both cases of hybrid attention and global attention?

Collaborator

It is deprecated, as only SinkCache used to use it; a small nit!
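In other words, the cache update only needs cache_position nowadays. A hedged, runnable sketch (DynamicCache is used purely for illustration; the shapes are made up):

import torch
from transformers import DynamicCache

# Sketch only: sin/cos no longer need to be forwarded in cache_kwargs (only the deprecated
# SinkCache used them); cache_position alone is enough, and static-shaped caches use it to
# know where to write.
past_key_value = DynamicCache()
key_states = torch.randn(1, 4, 3, 8)    # (batch, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(1, 4, 3, 8)
cache_position = torch.arange(3)

cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx=0, cache_kwargs=cache_kwargs)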

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
Collaborator

Suggested change
# sin and cos are specific to RoPE models; cache_position needed for the static cache

Comment on lines 353 to 356
        # Here we need to slice as we use a static cache by default, but FA2 does not support it
        # attention_mask can be None, so we use cache_position rather than attention_mask's shape
        # NOTE: seq_len can be retrieved from past_key_value.get_seq_length(),
        # but currently, only 0th-layer is used for .get_seq_length() in HybridCache.
        # This can cause issues when the 0th-layer is sliding window and seq_len > window_size,
        # as it affects full attention layers by slicing KV cache improperly.
        # Dynamic calculation of seq_len is not optimal for CUDAGraph, thus it seems to be updated later.
Collaborator

That is interesting, but not expected; we should do this directly in the static cache in general, or in the FlashAttention path if possible!

        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)


class Exaone4DecoderLayer(nn.Module):
Collaborator

Same as Olmo2, you can use it for inheritance!

Comment on lines 494 to 497
        # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
        if self.config.sliding_window is None:
            past_key_values = StaticCache(
                self.config,
                max_batch_size=batch_size,
                max_cache_len=seq_len,
                dtype=inputs_embeds.dtype,
                device=self.device,
            )
        else:
            past_key_values = HybridCache(
                self.config,
                max_batch_size=batch_size,
                max_cache_len=seq_len,
                dtype=inputs_embeds.dtype,
                device=self.device,
            )
Collaborator

In general we should just set cache_type to sliding or hybrid in the config based on sliding_window, as this should only be used for model.generate!

Contributor Author

As we understand it, should we use only HybridCache in this section? We have two types of models mentioned here, so we thought we needed to choose the proper cache type based on the model being used.

Collaborator

Right, this is a limitation on our side for now, but dynamic is the default cache, so let's go with this! Generate should choose hybrid vs static itself!

Comment on lines 536 to 524
        if self.config.sliding_window is not None:
            causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
Collaborator

I am guessing some models are not sliding?

Contributor Author

You're right. EXAONE 4.0 32B uses a hybrid architecture with a 1:3 local-global attention scheme, whereas the 1.2B model uses only global attention (full attention). Therefore, we need to consider both cases.
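For illustration, a sketch of how the two variants could be derived from the config (the function and string names are illustrative; the mask-mapping keys quoted above suggest "sliding_attention"/"full_attention"):

def build_layer_types(num_hidden_layers: int, sliding_window: int | None, pattern: int = 4) -> list[str]:
    # Sketch only: with no sliding window (e.g. the 1.2B model) every layer uses full attention;
    # with a sliding window and a pattern of 4 (the 1:3 scheme of the 32B model), every 4th
    # layer is full attention and the rest use sliding-window attention.
    if sliding_window is None:
        return ["full_attention"] * num_hidden_layers
    return [
        "full_attention" if (i + 1) % pattern == 0 else "sliding_attention"
        for i in range(num_hidden_layers)
    ]


# e.g. build_layer_types(8, sliding_window=4096) ->
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention', ...]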

Collaborator

Okay

@ArthurZucker
Collaborator

Also just waiting on the .md to be updated!

@ArthurZucker
Collaborator

@lgai-exaone sorry for the delay

@lgai-exaone
Contributor Author

Thank you for your consideration, @ArthurZucker!

Please review our replies to resolve the conversations :)

@ArthurZucker
Collaborator

Answered everything!

@lgai-exaone force-pushed the add-exaone4 branch 4 times, most recently from 8c251b2 to 58664a7 on July 23, 2025 01:17
@lkm2835
Contributor

lkm2835 commented Jul 25, 2025

Hi @ArthurZucker, thanks for your comments!
We've updated the code and docs.
What are the next steps for merging? vLLM has already merged it!
vllm-project/vllm#21060

@ArthurZucker
Collaborator

On it!

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, deepseek_vl, deepseek_vl_hybrid, evolla, exaone4

@ArthurZucker
Collaborator

@lgai-exaone I made a few changes; the layer_types approach is easiest to deal with, and sliding_window_pattern is now 4 instead of "lllg", since we are trying to standardize!

@ArthurZucker merged commit c06d4cd into huggingface:main Jul 25, 2025
20 of 23 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lgai-exaone
Contributor Author

Hello, @ArthurZucker. Thank you for your contribution!

We appreciate your hard work and will update our documentation with the transformers release. :)
We've started reviewing your latest changes to ensure the implementation aligns with our intentions.
Please review our feedback - we would be grateful for your response!

@ArthurZucker
Collaborator

Sure! Where can I see your review? 🤗

Comment on lines +216 to +217
        sliding_window=4096,
        sliding_window_pattern=4,
Contributor Author

Is it necessary to remove None as the default from the config? We intended for a newly initialized model to use full attention without any sliding window attention.

Collaborator

Absolutely not! We can use None here!
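A short usage sketch of the two intents discussed here (values taken from the quoted lines above; treat this as illustrative rather than the final defaults):

from transformers import Exaone4Config

# Full attention only: no sliding window, matching the intent for a freshly initialized model.
full_attention_config = Exaone4Config(sliding_window=None)

# Hybrid 1:3 local-global attention, as in the quoted defaults above.
hybrid_config = Exaone4Config(sliding_window=4096, sliding_window_pattern=4)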

            for i in range(self.num_hidden_layers)
        ]
        if "sliding_window" in self.layer_types:
            self._attn_implementation = "hybrid"
Contributor Author

It appears to be fixed. It would be better to continue the discussion on #39698

Comment on lines +384 to +385
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
Contributor Author

If the sliding window is set as the default, then it should be changed to HybridCache for consistency. Otherwise, we may need to change the default sliding window options in the config.

Collaborator

the cache_implementation in the config is supposed to do this!
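A hedged sketch of that flow from the user side (the checkpoint id is the released 32B model mentioned later in this thread; whether "hybrid" is picked automatically or passed explicitly depends on the config):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "LGAI-EXAONE/EXAONE-4.0-32B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
# The model itself only falls back to a DynamicCache; generate() is where a static/hybrid
# cache gets built, steered by cache_implementation in the (generation) config.
outputs = model.generate(**inputs, max_new_tokens=8, cache_implementation="hybrid")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))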

    def tearDown(self):
        # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
        # some memory allocated in the cache, which means some object is not being released properly. This causes some
        # unoptimal memory usage, e.g. after certain teruff format examples tests src utilssts a 7B model in FP16 no longer fits in a 24GB GPU.
Contributor Author

Maybe a typo? It seems like ruff format examples tests src utils can be removed.

Collaborator

yes!

@lgai-exaone
Contributor Author

> Sure! Where can I see your review? 🤗

Sorry for not submitting our reviews earlier. We have just submitted them.
Thank you for your efforts!

@ArthurZucker
Collaborator

No worries 🫡

zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* Add EXAONE 4.0 model

* Refactor EXAONE 4.0 modeling code

* Fix cache slicing on SWA + FA2

* Fix cache slicing on FA2 + HybridCache

* Update EXAONE 4.0 modeling code for main branch

* Update o_proj for asymmetric projection

* Address PR feedback

* Add EXAONE 4.0 docs

* Update EXAONE 4.0 modeling code for main branch

* update

* fix updates

* updates

* fix

* fix

* fix

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
@ydshieh
Collaborator

ydshieh commented Sep 23, 2025

Hi @lgai-exaone

We have

(line 508) OSError: LGAI-EXAONE/EXAONE-4.0-Instruct is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'

on our CI, since this model is added to the library.

I see

TEST_MODEL_ID = "LGAI-EXAONE/EXAONE-4.0-Instruct"  # dummy model id

Did you intend to use a dummy model? In any case, could you provide a repo id that could be used for testing?

@lkm2835 mentioned this pull request Sep 23, 2025
@lkm2835
Contributor

lkm2835 commented Sep 23, 2025

Hi @ydshieh
I fixed the dummy model to LGAI-EXAONE/EXAONE-4.0-32B.
