Mirror of https://github.com/hpcaitech/ColossalAI.git
commit ced6b5e1c3 ("fix")
parent 10bc6af2b1
@@ -123,6 +123,7 @@ class CommandPipelineForwards:
        # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
        shard_config.enable_flash_attention = True
        if shard_config.enable_flash_attention:
            # in this case, attention_mask is a dict rather than a tensor
            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
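For context, `(batch_size, 1, seq_length, seq_length_with_past)` is the usual 4D attention-mask layout (batch, broadcastable head dim, query length, key length). The snippet below is a hypothetical illustration in plain PyTorch of a boolean causal mask with that shape; it is not the dict that ColossalAI actually builds, whose exact keys and dtype depend on its ColoAttention utilities.

```python
# Hypothetical sketch only: a boolean causal mask with the 4D shape used above.
# Names and values here are illustrative, not taken from ColossalAI.
import torch

batch_size, seq_length, past_len = 2, 4, 3
seq_length_with_past = seq_length + past_len

mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
# query i (absolute position past_len + i) may attend to keys 0 .. past_len + i
causal = torch.ones(seq_length, seq_length_with_past).tril(diagonal=past_len).bool()
attention_mask = causal.expand(mask_shape)
print(attention_mask.shape)  # torch.Size([2, 1, 4, 7])
```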
@@ -391,6 +392,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = None

        shard_config.enable_flash_attention = True

        if shard_config.enable_flash_attention:
            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
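The `**attention_mask` unpacking above works because the mask is pre-packaged as a dict of keyword arguments for the attention kernel (hence the assertion). Below is only a stand-in sketch of that calling pattern, using `torch.nn.functional.scaled_dot_product_attention` instead of ColossalAI's `ColoAttention.attention`; the real keyword names may differ.

```python
# Stand-in sketch of the dict-as-kwargs pattern; not ColoAttention's real signature.
import torch
import torch.nn.functional as F

def attention(query, key, value, attention_mask=None, is_causal=False, dropout_p=0.0):
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal, dropout_p=dropout_p
    )

q = k = v = torch.randn(2, 8, 16, 64)                       # (batch, heads, seq, head_dim)
mask_kwargs = {"attention_mask": None, "is_causal": True}   # illustrative keys only
attn_output = attention(q, k, v, **mask_kwargs)
print(attn_output.shape)                                    # torch.Size([2, 8, 16, 64])
```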
@@ -402,7 +405,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
                attn_weights = attn_weights + causal_mask

            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            dropout = (0.0 if not self.training else self.attention_dropout,)
            dropout = 0.0 if not self.training else self.attention_dropout
            attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)
            attn_output = attn_output.transpose(1, 2).contiguous()
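The pair of `dropout` lines above is the heart of this fix: the trailing comma in the first form turns the conditional expression into a one-element tuple, so `nn.functional.dropout` receives a tuple for `p` instead of a float. A minimal standalone reproduction:

```python
# Minimal reproduction of the tuple bug fixed above (standalone, outside the model code).
import torch
import torch.nn as nn

training = True
attention_dropout = 0.1

dropout = (0.0 if not training else attention_dropout,)  # trailing comma -> (0.1,), a tuple
print(type(dropout))                                     # <class 'tuple'>

dropout = 0.0 if not training else attention_dropout     # corrected: a plain float
attn_weights = torch.rand(2, 8, 16, 16)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
print(attn_weights.shape)                                # torch.Size([2, 8, 16, 16])
```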
@@ -477,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        shard_config.enable_flash_attention = True

        # in this case, attention_mask is a dict rather than a tensor
        if shard_config.enable_flash_attention:
            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
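Here `cache_position` enumerates the absolute positions of the tokens processed in this forward pass, so `unsqueeze(0)` only adds the batch dimension expected for `position_ids`, while the mask covers all `past_seen_tokens + seq_len` positions. A small sketch with assumed values:

```python
# Sketch with assumed values; mirrors the shapes used above.
import torch

past_seen_tokens, seq_len = 5, 3
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)  # tensor([5, 6, 7])
position_ids = cache_position.unsqueeze(0)                                   # shape (1, 3)

batch_size = 2
mask_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
print(position_ids.shape, mask_shape)  # torch.Size([1, 3]) (2, 1, 8, 8)
```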
@@ -524,7 +529,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                    cache_position,
                    position_embeddings,
                )

            else:
                layer_outputs = decoder_layer(
                    hidden_states,
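The `cache_position` / `position_embeddings` arguments above close out what is presumably the gradient-checkpointing call for each decoder layer, paired with the direct call in the `else` branch. The generic pattern, sketched with a toy layer and PyTorch's `torch.utils.checkpoint` rather than ColossalAI's own gradient-checkpointing hooks:

```python
# Generic sketch of checkpoint-vs-direct dispatch; toy layer, not the real decoder layer.
import torch
from torch.utils.checkpoint import checkpoint

def run_layer(decoder_layer, hidden_states, gradient_checkpointing: bool, training: bool):
    if gradient_checkpointing and training:
        # recompute the layer's activations during backward to save memory
        return checkpoint(decoder_layer, hidden_states, use_reentrant=False)
    return decoder_layer(hidden_states)

layer = torch.nn.Linear(16, 16)  # stand-in for a decoder layer
x = torch.randn(2, 4, 16, requires_grad=True)
out = run_layer(layer, x, gradient_checkpointing=True, training=True)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4, 16])
```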
@@ -75,14 +75,15 @@ class CommandPolicy(Policy):
            policy[attn_cls] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )

        self.append_or_create_method_replacement(
            description={
                "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
            },
            policy=policy,
            target_key=attn_cls,
        )
        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                },
                policy=policy,
                target_key=attn_cls,
            )
            if self.pipeline_stage_manager is None:
                self.append_or_create_method_replacement(
                    description={
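This hunk shows the method replacement both unconditionally and guarded by `enable_flash_attention or enable_sequence_parallelism`; in either form, `append_or_create_method_replacement` registers the `forward` returned by `get_command_flash_attention_forward` for the target attention class. Conceptually, such a replacement amounts to rebinding `forward`, as in this simplified sketch, which is not ColossalAI's actual Policy/ShardFormer machinery:

```python
# Simplified illustration of a method replacement: rebinding forward on an instance.
from types import MethodType

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

def replacement_forward(self, x):
    # stand-in for the forward produced by get_command_flash_attention_forward(...)
    return self.proj(x) * 2.0

attn = ToyAttention()
attn.forward = MethodType(replacement_forward, attn)  # swap in the new forward
print(attn(torch.randn(1, 8)).shape)  # torch.Size([1, 8])
```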
@@ -213,6 +213,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
@@ -220,7 +231,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
@@ -231,7 +241,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
@@ -243,7 +252,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "pp_size": 2,
            "num_microbatches": 4,
            "use_lazy_init": False,
            "enable_flash_attention": True,
            "precision": "fp32",
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
@@ -252,7 +260,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
@@ -263,7 +270,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
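Each dict above is one parallelism/precision configuration for the shared `check_forward_backward` test body; judging by the `-213,6 +213,17` hunk header, the entry with `"enable_flash_attention": False` and all_to_all sequence parallelism is the newly added one. The loop below is only an illustrative, plain-Python sketch of how such a matrix might be consumed, not the real ColossalAI test harness:

```python
# Illustrative only: iterating over one of the configurations listed above.
test_configs = [
    {
        "tp_size": 4,
        "pp_size": 1,
        "num_microbatches": 1,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "all_to_all",
        "enable_flash_attention": False,
        "use_lazy_init": True,
        "precision": "fp16",
        "initial_scale": 1,
    },
]

for cfg in test_configs:
    print("tp=%d pp=%d flash=%s precision=%s"
          % (cfg["tp_size"], cfg["pp_size"], cfg["enable_flash_attention"], cfg["precision"]))
```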