add explanation

This commit is contained in:
wangbluo 2025-05-08 17:46:54 +08:00
parent 885210dc27
commit cefdfc4125

View File

@ -111,7 +111,6 @@ def get_tp_falcon_decoder_layer_forward():
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
):
residual = hidden_states
if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
@ -196,6 +195,8 @@ class FalconPipelineForwards:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
# Add cache_position and position_embeddings args for v4.51.3 transformers
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -261,7 +262,8 @@ class FalconPipelineForwards:
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# use new version of causal mask construction.
# In v4.51.3 version, sdpa, egaer and flash attention are merged into one class.
causal_mask = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
)