mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-28 20:53:58 +00:00
add explanation
This commit is contained in:
parent
885210dc27
commit
cefdfc4125
@ -111,7 +111,6 @@ def get_tp_falcon_decoder_layer_forward():
|
|||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
|
if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
|
||||||
@ -196,6 +195,8 @@ class FalconPipelineForwards:
|
|||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
shard_config: ShardConfig = None,
|
shard_config: ShardConfig = None,
|
||||||
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
# Add cache_position and position_embeddings args for v4.51.3 transformers
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@ -261,7 +262,8 @@ class FalconPipelineForwards:
|
|||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
# use new version of causal mask construction.
|
||||||
|
# In v4.51.3 version, sdpa, egaer and flash attention are merged into one class.
|
||||||
causal_mask = self._update_causal_mask(
|
causal_mask = self._update_causal_mask(
|
||||||
attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
|
attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user