mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 14:33:20 +00:00)

commit 10bc6af2b1 (parent ba9fb549d5)

    fix
@@ -134,6 +134,7 @@ class CommandPipelineForwards:
                is_causal=True,
            )
        else:
            # v4.51.3 transformers attention_mask calculation
            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
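For context, the else branch above delegates mask construction to transformers' _update_causal_mask, which returns an additive 4D mask. The sketch below is a minimal, self-contained approximation of that mask's shape and semantics only; it ignores the padding, static-cache and dtype handling the real helper covers, and build_causal_mask is a hypothetical name used here for illustration.

import torch

def build_causal_mask(batch_size: int, q_len: int, kv_len: int, dtype=torch.float32):
    # 0.0 where attention is allowed, dtype-min where it is masked out,
    # so the mask can simply be added to the raw attention scores.
    min_value = torch.finfo(dtype).min
    mask = torch.full((q_len, kv_len), min_value, dtype=dtype)
    mask = torch.triu(mask, diagonal=kv_len - q_len + 1)
    # Broadcast to (batch, 1, q_len, kv_len), matching the eager attention path.
    return mask[None, None, :, :].expand(batch_size, 1, q_len, kv_len)

causal_mask = build_causal_mask(batch_size=2, q_len=4, kv_len=4)
print(causal_mask.shape)  # torch.Size([2, 1, 4, 4])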
@@ -164,7 +165,7 @@ class CommandPipelineForwards:
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # v4.51.3 transformers position_embeddings calculation
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        start_idx, end_idx = stage_index[0], stage_index[1]
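The hunk above follows the v4.51.x convention of computing rotary position embeddings once at the model level and threading the resulting (cos, sin) pair through every decoder layer. The sketch below only illustrates the shape of that pair; rotary_cos_sin is a hypothetical name, head_dim and rope_theta are illustrative values, and the exact layout of the duplicated frequencies differs between model families.

import torch

def rotary_cos_sin(position_ids: torch.Tensor, head_dim: int = 64, rope_theta: float = 10000.0):
    # Inverse frequencies for each pair of rotated dimensions.
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # (batch, seq_len, head_dim // 2) rotation angles, duplicated to span head_dim.
    angles = position_ids[..., None].float() * inv_freq
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()

position_ids = torch.arange(8)[None, :]      # (1, 8)
cos, sin = rotary_cos_sin(position_ids)
print(cos.shape, sin.shape)                  # (1, 8, 64) each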
@@ -394,6 +395,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
        else:
            # attn_weights and attn_output calculation is adapted from v4.51.3 of transformers.models.cohere.modeling_cohere.CohereAttention.forward.
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
            if attention_mask is not None:
                causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
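The eager branch above continues with softmax and a weighted sum over the values (cut off by the hunk boundary). A self-contained sketch of that full eager path is shown below; eager_attention is a hypothetical name, the tensor shapes are illustrative, and the scaling factor is assumed to be 1/sqrt(head_dim) as in the Cohere attention it mirrors.

import math
import torch

def eager_attention(query_states, key_states, value_states, attention_mask=None):
    # (batch, heads, q_len, kv_len) raw scores, scaled by 1/sqrt(head_dim).
    scaling = 1.0 / math.sqrt(query_states.shape[-1])
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Slice the additive mask to the current key length and add it in.
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value_states)

q = k = v = torch.randn(2, 8, 16, 64)   # (batch, heads, seq_len, head_dim)
print(eager_attention(q, k, v).shape)   # torch.Size([2, 8, 16, 64])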
@@ -486,6 +488,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                is_causal=True,
            )
        else:
            # v4.51.3 transformers attention_mask calculation
            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

        if sp_mode in ["ring", "split_gather"]:
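The branch above either packs the mask into a kwargs dict for the flash-attention path (is_causal=True) or builds a dense additive mask for the eager path. The sketch below shows only that dispatch shape; dispatch_attention is a hypothetical name, ColoAttention's real keyword arguments are not reproduced, and torch's scaled_dot_product_attention merely stands in for the fused kernel.

import torch
import torch.nn.functional as F

def dispatch_attention(q, k, v, attention_mask, use_flash: bool):
    if use_flash:
        # Flash path: the mask is a kwargs dict that is unpacked into the call.
        assert isinstance(attention_mask, dict), "flash path expects a kwargs dict"
        return F.scaled_dot_product_attention(q, k, v, **attention_mask)
    # Eager path: attention_mask is a dense additive (batch, 1, q_len, kv_len) tensor.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)

q = k = v = torch.randn(2, 8, 16, 64)
out = dispatch_attention(q, k, v, {"is_causal": True}, use_flash=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])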
@@ -503,6 +506,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # v4.51.3 transformers position_embeddings calculation
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers: