Mirror of https://github.com/hpcaitech/ColossalAI.git
fix
@@ -123,6 +123,7 @@ class CommandPipelineForwards:
 
         # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
+        shard_config.enable_flash_attention = True
        if shard_config.enable_flash_attention:
            # in this case, attention_mask is a dict rather than a tensor
            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
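For context, this branch replaces the usual additive attention-mask tensor with a dict of attention keyword arguments when flash attention is enabled. Below is a minimal sketch of that idea in plain PyTorch; build_attn_kwargs is an illustrative name, not ColossalAI's ColoAttention helper, and the dict layout is an assumption for demonstration only.

import torch

def build_attn_kwargs(batch_size: int, seq_length: int, seq_length_with_past: int,
                      device: torch.device) -> dict:
    # Hypothetical helper: package mask information as keyword arguments
    # instead of a single additive mask tensor.
    mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
    # Boolean mask: True where a query position may attend to a key position.
    causal = torch.tril(
        torch.ones(seq_length, seq_length_with_past, dtype=torch.bool, device=device),
        diagonal=seq_length_with_past - seq_length,
    )
    attention_mask = causal.expand(mask_shape).contiguous()
    return {"attention_mask": attention_mask}

# Usage: the dict is later unpacked into the attention call, e.g. attn(**attn_kwargs).
attn_kwargs = build_attn_kwargs(batch_size=2, seq_length=4, seq_length_with_past=6,
                                device=torch.device("cpu"))
print(attn_kwargs["attention_mask"].shape)  # torch.Size([2, 1, 4, 6])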
@@ -391,6 +392,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
         value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = None
 
+        shard_config.enable_flash_attention = True
+
        if shard_config.enable_flash_attention:
            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
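The repeat_kv call above is the standard grouped-query-attention step: key/value heads are duplicated so their head count matches the query heads before the attention product. A rough sketch of what such a helper does, mirroring the common Hugging Face implementation; the exact version used in this file may differ.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    #   -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# Example: 2 KV heads repeated 4x to match 8 query heads.
kv = torch.randn(1, 2, 16, 64)
print(repeat_kv(kv, 4).shape)  # torch.Size([1, 8, 16, 64])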
@@ -402,7 +405,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
                 attn_weights = attn_weights + causal_mask
 
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            dropout = (0.0 if not self.training else self.attention_dropout,)
+            dropout = 0.0 if not self.training else self.attention_dropout
            attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)
            attn_output = attn_output.transpose(1, 2).contiguous()
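This hunk is the actual bug fix in the commit: the old line ended with a trailing comma, so dropout was bound to a one-element tuple rather than a float, and nn.functional.dropout(p=dropout, ...) then fails because p must be a number. A small standalone illustration of the difference:

import torch
import torch.nn as nn

attn_weights = torch.softmax(torch.randn(2, 8, 4, 4), dim=-1)
attention_dropout, training = 0.1, True

# Buggy version: the trailing comma makes this a tuple, e.g. (0.1,).
dropout = (0.0 if not training else attention_dropout,)
try:
    nn.functional.dropout(attn_weights, p=dropout, training=training)
except TypeError as exc:
    print("fails:", exc)

# Fixed version: a plain float.
dropout = 0.0 if not training else attention_dropout
out = nn.functional.dropout(attn_weights, p=dropout, training=training)
print(out.shape)  # torch.Size([2, 8, 4, 4])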
@@ -477,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
 
+        shard_config.enable_flash_attention = True
+
        # in this case, attention_mask is a dict rather than a tensor
        if shard_config.enable_flash_attention:
            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
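For reference, cache_position in Hugging Face-style forwards holds the absolute index of each new token within the full sequence including the KV cache, and position_ids defaults to that sequence with a leading batch dimension. A minimal sketch of the relationship; the concrete values are illustrative.

import torch

past_seen_tokens = 6   # tokens already in the KV cache
seq_len = 4            # new tokens in this forward pass

# Positions of the new tokens within the full sequence.
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
print(cache_position)            # tensor([6, 7, 8, 9])

# Default position_ids: same indices with a leading batch dimension.
position_ids = cache_position.unsqueeze(0)
print(position_ids.shape)        # torch.Size([1, 4])

# The flash-attention mask in the hunk above covers past + new tokens.
mask_shape = (1, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
print(mask_shape)                # (1, 1, 10, 10)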
@@ -524,7 +529,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     cache_position,
                    position_embeddings,
                )
 
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
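The hunk above sits at the boundary between the gradient-checkpointed call and the plain call of each decoder layer. As a rough, self-contained illustration of that pattern with torch.utils.checkpoint; TinyLayer is a stand-in, not the Command decoder layer.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyLayer(nn.Module):
    # Stand-in decoder layer: just a linear projection and an activation.
    def __init__(self, dim: int = 32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(hidden_states))

layer = TinyLayer()
hidden_states = torch.randn(2, 8, 32, requires_grad=True)
gradient_checkpointing, training = True, True

if gradient_checkpointing and training:
    # Recompute the layer's activations during backward to save memory.
    layer_outputs = checkpoint(layer.__call__, hidden_states, use_reentrant=False)
else:
    layer_outputs = layer(hidden_states)

layer_outputs.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 8, 32])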
@@ -75,14 +75,15 @@ class CommandPolicy(Policy):
         policy[attn_cls] = ModulePolicyDescription(
            attribute_replacement=decoder_attribute_replacement,
        )
 
-        self.append_or_create_method_replacement(
-            description={
-                "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
-            },
-            policy=policy,
-            target_key=attn_cls,
-        )
+        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+                },
+                policy=policy,
+                target_key=attn_cls,
+            )
        if self.pipeline_stage_manager is None:
            self.append_or_create_method_replacement(
                description={
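The policy change above makes the forward replacement conditional: the flash-attention forward is only patched onto the attention class when flash attention or sequence parallelism is actually enabled. A simplified, framework-free sketch of that guard pattern; ShardConfigStub and patch_forward are illustrative names, not ColossalAI APIs.

from dataclasses import dataclass
from types import MethodType

@dataclass
class ShardConfigStub:
    # Illustrative stand-in for the ShardConfig flags used in the hunk.
    enable_flash_attention: bool = False
    enable_sequence_parallelism: bool = False

class Attention:
    def forward(self, x):
        return f"eager({x})"

def flash_forward(self, x):
    return f"flash({x})"

def patch_forward(module: Attention, shard_config: ShardConfigStub) -> None:
    # Only replace the method when one of the relevant features is on,
    # mirroring the new `if` guard around append_or_create_method_replacement.
    if shard_config.enable_flash_attention or shard_config.enable_sequence_parallelism:
        module.forward = MethodType(flash_forward, module)

attn = Attention()
patch_forward(attn, ShardConfigStub(enable_flash_attention=True))
print(attn.forward("q"))   # flash(q)

attn2 = Attention()
patch_forward(attn2, ShardConfigStub())
print(attn2.forward("q"))  # eager(q)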