
Commit 363094c

Update EXAONE 4.0 modeling code for main branch
1 parent 996e417 commit 363094c

File tree

2 files changed: +54, -191 lines


src/transformers/models/exaone4/modeling_exaone4.py

Lines changed: 34 additions & 107 deletions
@@ -20,7 +20,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import partial
 from typing import Callable, Optional, Union
 
 import torch
@@ -31,7 +30,6 @@
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,7 +40,8 @@
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
 from .configuration_exaone4 import Exaone4Config
 
 
@@ -74,7 +73,7 @@ class Exaone4RotaryEmbedding(nn.Module):
     def __init__(self, config: Exaone4Config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
             self.rope_type = "default"
@@ -158,7 +157,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -239,7 +238,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -280,13 +279,7 @@ def forward(
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,
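With the SDPA-specific `output_attentions` fallback removed, the backend is now picked purely by name from the attention-function registry. A rough sketch of that dispatch pattern, using a toy registry in place of ALL_ATTENTION_FUNCTIONS (the stub functions are assumptions for illustration):

from typing import Callable

import torch


def eager_attention_stub(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Toy eager path: plain softmax attention.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v


def sdpa_attention_stub(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


# Toy registry standing in for ALL_ATTENTION_FUNCTIONS.
ATTENTION_REGISTRY: dict[str, Callable] = {"sdpa": sdpa_attention_stub}

attn_implementation = "sdpa"
attention_interface = eager_attention_stub
if attn_implementation != "eager":
    # No output_attentions special-casing: the configured name decides the backend.
    attention_interface = ATTENTION_REGISTRY[attn_implementation]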
@@ -337,11 +330,6 @@ def __init__(self, config: Exaone4Config, layer_idx: int):
 
         self.is_sliding = check_is_sliding(config, layer_idx)
         self.sliding_window = config.sliding_window
-        if config.sliding_window and config._attn_implementation == "sdpa":
-            logger.warning_once(
-                f"Sliding Window Attention is enabled but not optimized for `{config._attn_implementation}`; "
-                "unexpected results may be encountered."
-            )
 
     def forward(
         self,
@@ -350,28 +338,26 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         residual = hidden_states
 
         # Self Attention
-        hidden_states, attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            **kwargs,
         )
 
         # Use post-LN
         hidden_states = self.post_attention_layernorm(hidden_states)
-
         hidden_states = residual + hidden_states
 
         residual = hidden_states
@@ -381,14 +367,9 @@ def forward(
 
         # Use post-LN
         hidden_states = self.post_feedforward_layernorm(hidden_states)
-
         hidden_states = residual + hidden_states
 
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 @auto_docstring
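After this hunk the decoder layer returns the hidden-states tensor directly instead of an (hidden_states, attn_weights) tuple, and the `output_attentions` bookkeeping disappears from the layer entirely. A rough sketch of the new post-LN layer contract; the module below is a placeholder, not the actual Exaone4DecoderLayer:

import torch
from torch import nn


class PostLNBlockSketch(nn.Module):
    # Placeholder block: attn and mlp are stand-ins for Exaone4Attention / Exaone4MLP.
    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Identity()
        self.mlp = nn.Identity()
        self.post_attention_layernorm = nn.LayerNorm(dim)
        self.post_feedforward_layernorm = nn.LayerNorm(dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.attn(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        # Single tensor out -- no (hidden_states, attn_weights) tuple to unpack.
        return residual + hidden_states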
@@ -398,14 +379,17 @@ class Exaone4PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Exaone4DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
     _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": Exaone4DecoderLayer,
+        "attentions": Exaone4Attention,
+    }
 
     def _init_weights(self, module):
         std = self.config.initializer_range
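The `_can_record_outputs` registry is what takes over from the deleted tuple-collection code: paired with the `@check_model_inputs` decorator added below, the library records per-layer hidden states and attention weights from the listed module classes when a caller requests them. A hedged usage sketch, assuming the standard flags still work as keyword arguments (the model ID is the placeholder quoted in the docstring later in this file):

# Sketch: requesting recorded outputs at call time (model ID is a placeholder).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")
model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")

inputs = tokenizer("Hello", return_tensors="pt")
# If recording works as assumed, the flags route through **kwargs and the
# registry above decides which modules' outputs populate these fields.
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
print(len(outputs.attentions), len(outputs.hidden_states))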
@@ -421,9 +405,6 @@ def _init_weights(self, module):
             module.weight.data.fill_(1.0)
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 @auto_docstring
 class Exaone4Model(Exaone4PreTrainedModel):
     def __init__(self, config: Exaone4Config):
@@ -448,7 +429,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -458,15 +439,9 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -519,6 +494,7 @@ def forward(
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
+                "position_ids": position_ids,
             }
             # Create the masks
             causal_mask_mapping = {
@@ -532,55 +508,23 @@ def forward(
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings,
-                    causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings=position_embeddings,
-                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
 
         hidden_states = self.norm(hidden_states)
 
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )
 
 
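The layer loop collapses to a single call per layer because the layer now returns a tensor, and the explicit gradient-checkpointing branch (along with the `functools.partial` import removed at the top) is gone, presumably handled generically elsewhere in the stack. A minimal, self-contained sketch of the simplified loop shape with placeholder modules:

import torch
from torch import nn

# Placeholder stack: each layer maps a tensor to a tensor, mirroring the new loop
# in which no tuples are unpacked and no per-layer outputs are accumulated by hand.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
norm = nn.LayerNorm(16)

hidden_states = torch.randn(2, 8, 16)
for layer in layers:
    hidden_states = layer(hidden_states)
hidden_states = norm(hidden_states)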
@@ -628,11 +572,9 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -666,21 +608,13 @@ def forward(
         ```
 
         NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future."""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
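The causal-LM head no longer falls back to `config.output_attentions` / `config.output_hidden_states` itself; under the new scheme that normalization presumably happens once, in the `@check_model_inputs` wrapper on the base model. An illustrative (not the actual) sketch of such a wrapper:

import functools


def check_model_inputs_sketch(forward):
    # Hypothetical stand-in for transformers.utils.generic.check_model_inputs:
    # fill missing recording flags from the model config before calling forward.
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        kwargs.setdefault("output_attentions", getattr(self.config, "output_attentions", False))
        kwargs.setdefault("output_hidden_states", getattr(self.config, "output_hidden_states", False))
        return forward(self, *args, **kwargs)

    return wrapper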
@@ -744,8 +678,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -761,8 +694,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
         hidden_states = transformer_outputs.last_hidden_state
         logits = self.score(hidden_states)
@@ -838,8 +770,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> TokenClassifierOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -855,8 +786,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
         sequence_output = outputs.last_hidden_state
         sequence_output = self.dropout(sequence_output)
@@ -903,18 +833,15 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,
         end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
 
         sequence_output = outputs.last_hidden_state
