[bugfix] colo attn bug fix

haze188 2024-07-24 06:53:24 +00:00
parent e521890d32
commit 2d73efdfdd
5 changed files with 106 additions and 55 deletions

View File

@@ -73,8 +73,8 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
         }
-        if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
-            raise ValueError("No parameters found in dp_process_group or moe_dp_group")
+        # if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
+        #     raise ValueError("No parameters found in dp_process_group or moe_dp_group")
         super().__init__(
             model=model,
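
A minimal sketch of the parameter grouping that the commented-out guard validated; `nn.Linear` is a hypothetical stand-in for the wrapped model, and `is_moe_tensor` is the same predicate used in `pg_param_list` above:

# Sketch only: a plain nn.Linear stands in for the real model, so the MoE bucket
# comes out empty here, which is the situation the guard above used to reject.
import torch.nn as nn

from colossalai.tensor.moe_tensor.api import is_moe_tensor

model = nn.Linear(4, 4)
moe_params = list(filter(is_moe_tensor, model.parameters()))  # [] for a dense-only model
dense_params = [p for p in model.parameters() if not is_moe_tensor(p)]  # weight + bias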

View File

@@ -1,7 +1,9 @@
+import math
 import warnings
 from typing import List, Optional, Tuple, Union

 import torch
+import torch.distributed
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
@@ -34,6 +36,8 @@ from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
 from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group

+from ..layer import ColoAttention
+

 # copied from modeling_deepseek.py
 class AddAuxiliaryLoss(torch.autograd.Function):
@@ -529,34 +533,30 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         output_attentions = False

-        bsz, q_len, _ = hidden_states.size()
+        bsz, q_len, _ = hidden_states.size()  # 1 4, 32

         # sp: modify sp_len when sequence parallel mode is ring
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
-        import torch.distributed as dist
-        rank = dist.get_rank()
-        print(f"{rank=}, hidden states:{hidden_states.shape}")
+        dist.get_rank()

         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-        rank = dist.get_rank()
-        print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group)
             key_states = all_to_all_comm(key_states, sp_group)
             value_states = all_to_all_comm(value_states, sp_group)
             bsz, q_len, _ = query_states.size()
-            print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -565,7 +565,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
         )
-        print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
@@ -573,13 +572,11 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        print(
-            f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
-        )
-        dropout_rate = self.attention_dropout if self.training else 0.0
+        # query_states = query_states.transpose(1, 2)
+        # key_states = key_states.transpose(1, 2)
+        # value_states = value_states.transpose(1, 2)
+        # dropout_rate = self.attention_dropout if self.training else 0.0

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -606,22 +603,57 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
             query_states = query_states.to(target_dtype)
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)

-        print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
-        )
+        # attn_output = self._flash_attention_forward(
+        #     query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        # )
+
+        if shard_config.enable_flash_attention:
+            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
+            attn_output = attn_output.transpose(1, 2).contiguous()

         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
-            # print(f"{rank=}, shard attn output after all to all:{attn_output[0][0]}")
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
+            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)
-        # print(f"{rank=}, {attn_output[0][0]}")

         if not output_attentions:
             attn_weights = None
+
+        import torch.distributed as dist
+
+        dist.get_rank()
+        # print(f"{rank=}, {attn_output[0][0]}")
         return attn_output, attn_weights, past_key_value

     return forward
@@ -683,24 +715,38 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+        if shard_config.enable_flash_attention:
+            mask_shape = (
+                inputs_embeds.shape[0],
+                1,
+                past_key_values_length + inputs_embeds.shape[1],
+                past_key_values_length + inputs_embeds.shape[1],
+            )
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                inputs_embeds.dtype,
+                inputs_embeds.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            if self._use_flash_attention_2:
+                # 2d mask is passed through the layers
+                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._use_sdpa and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            else:
+                # 4d mask is passed through the layers
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )

         if sp_mode in ["ring", "split_gather"]:
             inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
@@ -714,7 +760,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -746,8 +792,10 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

+        # import torch.distributed as dist
+        # rank = dist.get_rank()
+        # print(f"{rank=}, {hidden_states[0][0]}")
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
             hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
         elif sp_mode == "all_to_all":
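
The two ColoAttention call sites added in this file form a single path: the model forward builds a dict of attention kwargs once per batch, and the attention forward splats it into ColoAttention.attention. A minimal sketch of that contract follows; the sizes, dtype, and CUDA device are illustrative assumptions, not values from this commit.

# Sketch of the ColoAttention usage added above. Shapes, dtype, and device are
# assumptions for illustration only.
import torch

from colossalai.shardformer.layer import ColoAttention

bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64
dtype, device = torch.float16, "cuda"

# Model-forward side: build the kwargs dict that replaces the dense 4D causal mask.
mask_shape = (bsz, 1, seq_len, seq_len)
attn_kwargs = ColoAttention.prepare_attn_kwargs(mask_shape, dtype, device, is_causal=True)

# Attention-forward side: q/k/v are laid out as [bsz, num_heads, seq_len, head_dim]
# (the .view(...).transpose(1, 2) in the hunk above), and the kwargs are splatted in.
q = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=dtype, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
attn_output = ColoAttention.attention(q, k, v, **attn_kwargs)

When the batch carries padding, the original 2D attention_mask is forwarded as q_padding_mask, as the model-forward hunk above does.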

View File

@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union

@@ -194,11 +193,11 @@ class DeepseekPolicy(Policy):
                 target_key="DeepseekModel",
             )

-        if self.shard_config.enable_flash_attention:
-            warnings.warn(
-                "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
-            )
-            self.shard_config.enable_flash_attention = False
+        # if self.shard_config.enable_flash_attention:
+        #     warnings.warn(
+        #         "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
+        #     )
+        #     self.shard_config.enable_flash_attention = False
         return policy

View File

@@ -59,7 +59,7 @@ def init_deepseek():
        num_attention_heads=8,
        num_key_value_heads=8,
        # vocab_size=2200,
-        first_k_dense_replace=1,
+        first_k_dense_replace=2,
        attn_implementation="flash_attention_2",
        torch_dtype="float16",
        n_routed_experts=8,
@@ -68,7 +68,6 @@ def init_deepseek():
     if hasattr(config, "pad_token_id"):
         config.pad_token_id = config.eos_token_id
-    print(config)
     model = transformers.AutoModel.from_config(config, trust_remote_code=True)
     return model

View File

@@ -30,7 +30,12 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
     # TODO: SGD failed for full dp
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
-        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
+        # model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
+        model_fn,
+        loss_fn,
+        test_config,
+        pluggin_cls=MoeHybridParallelPlugin,
+        optim_class=torch.optim.SGD,
     )

     org_model = org_model.to(torch.float16)
@@ -39,16 +44,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )

     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
-    rank = dist.get_rank()

     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-5, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
         check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

     # unwrap model
     mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
@@ -178,12 +182,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         "sp_size": 2,
         "ep_size": 2,
         "enable_sequence_parallelism": True,
+        "enable_flash_attention": True,
         "sequence_parallelism_mode": "all_to_all",
         "zero_stage": 1,
         "overlap_communication": False,
         "precision": "fp16",
         "initial_scale": 1,
-        "find_unused_parameters": True,
+        # "find_unused_parameters": True,
     },
     # {
     #     "tp_size": 1,
@@ -224,7 +229,7 @@ def check_deepseek(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_mixtral():
-    spawn(check_deepseek, 4)
+    spawn(check_deepseek, 2)


 if __name__ == "__main__":