@@ -20,7 +20,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import partial
 from typing import Callable, Optional, Union

 import torch
@@ -31,7 +30,6 @@
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,7 +40,8 @@
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
 from .configuration_exaone4 import Exaone4Config


@@ -74,7 +73,7 @@ class Exaone4RotaryEmbedding(nn.Module):
     def __init__(self, config: Exaone4Config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
             self.rope_type = "default"
@@ -158,7 +157,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -239,7 +238,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -280,13 +279,7 @@ def forward(

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         attn_output, attn_weights = attention_interface(
             self,
@@ -337,11 +330,6 @@ def __init__(self, config: Exaone4Config, layer_idx: int):

         self.is_sliding = check_is_sliding(config, layer_idx)
         self.sliding_window = config.sliding_window
-        if config.sliding_window and config._attn_implementation == "sdpa":
-            logger.warning_once(
-                f"Sliding Window Attention is enabled but not optimized for `{config._attn_implementation}`; "
-                "unexpected results may be encountered."
-            )

     def forward(
         self,
@@ -350,28 +338,26 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         residual = hidden_states

         # Self Attention
-        hidden_states, attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            **kwargs,
         )

         # Use post-LN
         hidden_states = self.post_attention_layernorm(hidden_states)
-
         hidden_states = residual + hidden_states

         residual = hidden_states
@@ -381,14 +367,9 @@ def forward(

         # Use post-LN
         hidden_states = self.post_feedforward_layernorm(hidden_states)
-
         hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states


 @auto_docstring
@@ -398,14 +379,17 @@ class Exaone4PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Exaone4DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
     _supports_static_cache = True
     _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": Exaone4DecoderLayer,
+        "attentions": Exaone4Attention,
+    }

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -421,9 +405,6 @@ def _init_weights(self, module):
             module.weight.data.fill_(1.0)


-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 @auto_docstring
 class Exaone4Model(Exaone4PreTrainedModel):
     def __init__(self, config: Exaone4Config):
@@ -448,7 +429,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -458,15 +439,9 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache

         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -519,6 +494,7 @@ def forward(
             "attention_mask": attention_mask,
             "cache_position": cache_position,
             "past_key_values": past_key_values,
+            "position_ids": position_ids,
         }
         # Create the masks
         causal_mask_mapping = {
@@ -532,55 +508,23 @@ def forward(
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings,
-                    causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings=position_embeddings,
-                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )

         hidden_states = self.norm(hidden_states)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )


@@ -628,11 +572,9 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -666,21 +608,13 @@ def forward(
         ```

         NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future."""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
@@ -744,8 +678,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -761,8 +694,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
         hidden_states = transformer_outputs.last_hidden_state
         logits = self.score(hidden_states)
@@ -838,8 +770,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> TokenClassifierOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -855,8 +786,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
         sequence_output = outputs.last_hidden_state
         sequence_output = self.dropout(sequence_output)
@@ -903,18 +833,15 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,
         end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
    ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )

         sequence_output = outputs.last_hidden_state
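For context, here is a minimal sketch (not part of the commit) of the caller-side behavior this refactor preserves: with `@check_model_inputs` and the `_can_record_outputs` mapping, `output_attentions` / `output_hidden_states` travel through `**kwargs: Unpack[TransformersKwargs]` and the per-layer tensors are recorded by hooks instead of being accumulated in the decoder loop. The tiny config values are arbitrary, and the sketch assumes a transformers release that exports the Exaone4 classes and ships these utilities.

import torch
from transformers import Exaone4Config, Exaone4Model

# Tiny, randomly initialized model; sizes are made up so the smoke test runs on CPU.
config = Exaone4Config(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = Exaone4Model(config).eval()
# Force the eager attention path so attention weights are actually materialized
# (the forward above dispatches on `self.config._attn_implementation`).
model.config._attn_implementation = "eager"

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    # These flags are no longer explicit forward arguments; they are picked up from
    # kwargs by the decorator, which records outputs of the modules registered in
    # `_can_record_outputs` (Exaone4DecoderLayer / Exaone4Attention).
    out = model(input_ids, output_attentions=True, output_hidden_states=True)

print(len(out.attentions))     # one attention map per decoder layer
print(len(out.hidden_states))  # recorded per-layer hidden states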