mirror of https://github.com/hpcaitech/ColossalAI.git

commit ced6b5e1c3
parent 10bc6af2b1

    fix
@@ -123,6 +123,7 @@ class CommandPipelineForwards:

         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
+        shard_config.enable_flash_attention = True
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
             mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
@@ -391,6 +392,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = None

+        shard_config.enable_flash_attention = True
+
         if shard_config.enable_flash_attention:
             assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
             attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
@@ -402,7 +405,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
             attn_weights = attn_weights + causal_mask

             attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            dropout = (0.0 if not self.training else self.attention_dropout,)
+            dropout = 0.0 if not self.training else self.attention_dropout
             attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
             attn_output = torch.matmul(attn_weights, value_states)
             attn_output = attn_output.transpose(1, 2).contiguous()
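Aside on the dropout line changed just above: the trailing comma in the old version turned the conditional expression into a one-element tuple, and torch.nn.functional.dropout rejects a tuple for its p argument. A minimal sketch of the difference (the variable names below are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

training = True
attention_dropout = 0.1
x = torch.randn(2, 3)

# Old form: the trailing comma makes this a one-element tuple, not a float.
dropout = (0.0 if not training else attention_dropout,)
print(type(dropout))  # <class 'tuple'>
# F.dropout(x, p=dropout, training=training)  # would fail: p must be a number in [0, 1]

# Fixed form: a plain float, which nn.functional.dropout expects.
dropout = 0.0 if not training else attention_dropout
y = F.dropout(x, p=dropout, training=training)
print(type(dropout), y.shape)  # <class 'float'> torch.Size([2, 3])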
@@ -477,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

+        shard_config.enable_flash_attention = True
+
         # in this case, attention_mask is a dict rather than a tensor
         if shard_config.enable_flash_attention:
             mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
@@ -524,7 +529,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     cache_position,
                     position_embeddings,
                 )
-
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
@@ -75,14 +75,15 @@ class CommandPolicy(Policy):
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
+
+        self.append_or_create_method_replacement(
+            description={
+                "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+            },
+            policy=policy,
+            target_key=attn_cls,
+        )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
             if self.pipeline_stage_manager is None:
                 self.append_or_create_method_replacement(
                     description={
@@ -213,6 +213,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 1,
             "pp_size": 1,
@@ -220,7 +231,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -231,7 +241,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
@@ -243,7 +252,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "use_lazy_init": False,
-            "enable_flash_attention": True,
             "precision": "fp32",
             "enable_gradient_checkpointing": True,
             "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
@@ -252,7 +260,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "enable_all_optimization": True,
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -263,7 +270,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",