Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 14:33:20 +00:00)
fix

commit 10bc6af2b1
parent ba9fb549d5
@@ -134,6 +134,7 @@ class CommandPipelineForwards:
                 is_causal=True,
             )
         else:
+            # v4.51.3 transformers attention_mask calculation
             attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
 
         if self.gradient_checkpointing and self.training and use_cache:
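The same `_update_causal_mask` pattern reappears in the model-forward hunk further down (`@@ -486,6 +488,7 @@`). For readers unfamiliar with the v4.51.x mask handling, below is a minimal, self-contained sketch of the kind of 4D additive causal mask such a call produces for the non-flash path; `build_causal_mask` and its arguments are hypothetical names for illustration, not the transformers API.

import torch

def build_causal_mask(attention_mask, dtype, seq_len, kv_len):
    # Hypothetical helper: approximates the 4D additive mask that the
    # transformers-internal _update_causal_mask returns for the eager path.
    min_value = torch.finfo(dtype).min
    # Fully masked (q_len, kv_len) grid; keep min_value only strictly above the
    # causal diagonal, zero (= attend) elsewhere.
    causal = torch.full((seq_len, kv_len), min_value, dtype=dtype)
    causal = torch.triu(causal, diagonal=kv_len - seq_len + 1)
    causal = causal[None, None, :, :]  # (1, 1, q_len, kv_len)
    if attention_mask is not None:
        # attention_mask: (batch, kv_len) with 1 = keep, 0 = padding.
        causal = causal.expand(attention_mask.shape[0], -1, -1, -1).clone()
        keep = attention_mask[:, None, None, :].to(torch.bool)
        causal = causal.masked_fill(~keep, min_value)
    return causal  # added to attention scores before softmax

In the hunk above this role is filled by `self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)`, while the flash-attention branch (the `is_causal=True` call above the `else:`) instead prepares the mask dict that `ColoAttention.attention` later consumes.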
@@ -164,7 +165,7 @@ class CommandPipelineForwards:
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
-
+        # v4.51.3 transformers position_embeddings calculation
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         start_idx, end_idx = stage_index[0], stage_index[1]
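The `position_embeddings = self.rotary_emb(hidden_states, position_ids)` line reflects the v4.51.x convention of computing the rotary `(cos, sin)` pair once at the model level and handing it to every decoder layer. The sketch below shows the generic way such a pair is applied inside attention; `rotate_half` and `apply_rotary` are the common LLaMA-style helpers, and Cohere's actual implementation interleaves dimensions slightly differently, so treat this as an illustration only.

import torch

def rotate_half(x):
    # LLaMA-style helper: rotate the last dimension by halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin: (batch, seq_len, head_dim) as returned by the model-level
    # rotary module; unsqueeze to broadcast over the head dimension of
    # q/k shaped (batch, num_heads, seq_len, head_dim).
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

Both forwards touched by this commit build the pair before their decoder-layer loops: note `start_idx, end_idx` here and `for decoder_layer in self.layers:` in the `@@ -503,6 +506,7 @@` hunk below.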
@@ -394,6 +395,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
             assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
             attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
         else:
+            # attn_weights and attn_output calculation is adapted from v4.51.3 transformers.models.cohere.modeling_cohere.CohereAttention.forward
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
             if attention_mask is not None:
                 causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
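The `else:` branch above is the eager (non-flash) attention path. A minimal sketch of how those two lines typically continue, assuming query/key/value states are already rotary-embedded, shaped (batch, heads, seq, head_dim), and with key/value heads already expanded to the full head count; `eager_attention` and its arguments are illustrative names, not code from this commit.

import torch
import torch.nn.functional as F

def eager_attention(query_states, key_states, value_states, causal_mask, scaling, dropout_p=0.0, training=False):
    # Scaled scores, matching the matmul/scaling line in the hunk above.
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
    if causal_mask is not None:
        # Slice the additive mask to the current key length, then add it.
        attn_weights = attn_weights + causal_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout_p, training=training)
    attn_output = torch.matmul(attn_weights, value_states)  # (batch, heads, q_len, head_dim)
    return attn_output, attn_weights

The flash branch skips all of this: the assert only confirms that `attention_mask` is the kwargs dict expected by `ColoAttention.attention`.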
@@ -486,6 +488,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                 is_causal=True,
             )
         else:
+            # v4.51.3 transformers attention_mask calculation
             attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
 
         if sp_mode in ["ring", "split_gather"]:
@@ -503,6 +506,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
+        # v4.51.3 transformers position_embeddings calculation
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for decoder_layer in self.layers: