add explanation

This commit is contained in:
wangbluo 2025-05-08 17:46:54 +08:00
parent 885210dc27
commit cefdfc4125

View File

@ -111,7 +111,6 @@ def get_tp_falcon_decoder_layer_forward():
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs, **kwargs,
): ):
residual = hidden_states residual = hidden_states
if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2: if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
@ -196,6 +195,8 @@ class FalconPipelineForwards:
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
# Add cache_position and position_embeddings args for v4.51.3 transformers
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@ -261,7 +262,8 @@ class FalconPipelineForwards:
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# Use the new version of causal-mask construction.
# In v4.51.3, the sdpa, eager, and flash attention implementations are merged into one class.
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
) )