Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-21 10:50:56 +00:00

commit b2952a5982 (parent 96d0fbc531)
[moe] deepseek moe sp support

The diff below adds sequence-parallelism (SP) support for the DeepSeek MoE model in Shardformer: a sequence-parallel flash-attention forward and model forward, the matching policy changes, a model-zoo entry for a small DeepSeek MoE config, and a test exercising the "all_to_all" (Ulysses) SP mode.
@@ -1,12 +1,18 @@
-from typing import List, Optional
+import warnings
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
@@ -18,6 +24,11 @@ from colossalai.moe._operation import (
     all_to_all_uneven,
 )
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -362,7 +373,14 @@ class DeepseekPipelineForwards:
         next_cache = next_decoder_cache if use_cache else None
 
         if stage_manager.is_last_stage():
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+            if not return_dict:
+                return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+            return BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=next_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+            )
         # always return dict for imediate stage
         return {
             "hidden_states": hidden_states,
@@ -479,3 +497,276 @@ class DeepseekPipelineForwards:
             hidden_states = outputs.get("hidden_states")
             out["hidden_states"] = hidden_states
         return out
+
+
+def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+    logger = logging.get_logger(__name__)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
+
+        # DeepseekFlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # sp: modify sp_len when sequence parallel mode is ring
+        if sp_mode in ["split_gather", "ring"]:
+            q_len *= sp_size
+
+        rank = dist.get_rank()
+        print(f"{rank=}, hidden states:{hidden_states.shape}")
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        rank = dist.get_rank()
+        print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+        # sp: all-to-all comminucation when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            query_states = all_to_all_comm(query_states, sp_group)
+            key_states = all_to_all_comm(key_states, sp_group)
+            value_states = all_to_all_comm(value_states, sp_group)
+            bsz, q_len, _ = query_states.size()
+        print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
+        )
+        print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        print(
+            f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
+        )
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (DeepseekRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+        print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        )
+        # sp: all-to-all comminucation when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
+            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
+        else:
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    return forward
+
+
+def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+    logger = logging.get_logger(__name__)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape[:2]
+        elif inputs_embeds is not None:
+            batch_size, seq_length = inputs_embeds.shape[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
+                )
+                use_cache = False
+
+        past_key_values_length = 0
+        if use_cache:
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+
+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        if sp_mode == "ring" or sp_mode == "split_gather":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )

+    return forward
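The "all_to_all" branch above is the Ulysses-style sequence-parallel path: before attention, each rank trades its shard of the hidden/head dimension for the other ranks' shards of the sequence dimension, so attention runs over the full sequence but only over num_heads / sp_size heads; the second all-to-all after attention reverses the exchange, as the (1, 8, 128) -> (1, 4, 256) shape comments indicate. Below is a minimal, single-process sketch of only that shape bookkeeping; fake_all_to_all and the toy sizes are illustrative assumptions, not the ColossalAI all_to_all_comm kernel, and no real communication takes place.

# Single-process sketch of the layout change behind the all-to-all sequence
# parallelism above. Each "rank" is faked as a chunk of the same tensor; only
# the shapes are meaningful, the data placement of a real all-to-all is not
# reproduced here.
import torch


def fake_all_to_all(x: torch.Tensor, sp_size: int, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # split one piece per rank along scatter_dim, then concatenate along gather_dim
    pieces = x.chunk(sp_size, dim=scatter_dim)
    return torch.cat(pieces, dim=gather_dim)


sp_size, bsz, sub_seq, hidden = 2, 1, 4, 256  # each rank holds seq_len / sp_size tokens

# before attention: every rank holds all hidden channels for its sub-sequence
q_local = torch.randn(bsz, sub_seq, hidden)

# q/k/v all-to-all: scatter the hidden dim, gather the sequence dim
q_full_seq = fake_all_to_all(q_local, sp_size, scatter_dim=2, gather_dim=1)
print(q_full_seq.shape)  # (1, 8, 128): full sequence, 1/sp_size of the heads

# after attention: the inverse exchange restores the local sub-sequence with all heads
attn_out = fake_all_to_all(q_full_seq, sp_size, scatter_dim=1, gather_dim=2)
print(attn_out.shape)  # (1, 4, 256)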
@@ -7,8 +7,14 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
 from colossalai.shardformer.layer.linear import Linear1D_Row
-from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
+from colossalai.shardformer.modeling.deepseek import (
+    DeepseekPipelineForwards,
+    EPDeepseekMoE,
+    get_deepseek_flash_attention_forward,
+    get_deepseek_flash_attention_model_forward,
+)
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
@@ -19,6 +25,13 @@ class DeepseekPolicy(Policy):
         pass
 
     def preprocess(self):
+        self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
+        """
+        Because transformers library's bug for AutoModel/AutoConfig, who pop “attn_implement” twice from modeling_utils.py and configuration_utils.py.
+        This bug causes attn_cls to be set to sdpa. Here we assign it to "flash_attention_2".
+        """
+        # self.origin_attn_implement = "flash_attention_2"
         if self.shard_config.enable_tensor_parallelism:
             # Resize embedding
             vocab_size = self.model.config.vocab_size
@@ -31,17 +44,61 @@ class DeepseekPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        policy = {}
+
+        ATTN_IMPLEMENTATION = {
+            "eager": "DeepseekAttention",
+            "flash_attention_2": "DeepseekFlashAttention2",
+            "sdpa": "DeepseekSdpaAttention",
+        }
+        policy = {}
+        print(f"{self.origin_attn_implement=}")
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
+        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        if sp_mode == "all_to_all":
+            decoder_attribute_replacement = {
+                "num_heads": self.model.config.num_attention_heads // sp_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+            policy[attn_cls] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+            )
         if self.shard_config.enable_sequence_parallelism:
             if self.pipeline_stage_manager is not None:
                 # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
                 # if both are enabled, one of them will be ignored
                 raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
-            raise NotImplementedError(
-                "Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+            print(f"{attn_cls=}")
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+                },
+                policy=policy,
+                target_key=attn_cls,
             )
+            if self.pipeline_stage_manager is None:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_deepseek_flash_attention_model_forward(
+                            self.shard_config,
+                            sp_mode=sp_mode,
+                            sp_size=sp_size,
+                            sp_group=sp_group,
+                        ),
+                    },
+                    policy=policy,
+                    target_key="DeepseekModel",
+                )
+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
         if self.shard_config.enable_tensor_parallelism:
             # tensor parallelism for non-moe params
             assert (
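For orientation: append_or_create_method_replacement registers a replacement so that, when the model is sharded, the target module's forward is rebound to the closure returned by get_deepseek_flash_attention_forward / get_deepseek_flash_attention_model_forward. Below is a minimal sketch of the underlying Python idiom only (rebinding a function onto one module instance); TinyBlock, get_custom_forward, and scale are made-up stand-ins for the patched attention class and the captured shard_config / sp_* arguments, not Shardformer internals.

from types import MethodType

import torch
import torch.nn as nn


def get_custom_forward(scale: float):
    # configuration captured by the closure, mirroring how shard_config / sp_* are captured above
    def forward(self, x):
        return self.linear(x) * scale

    return forward


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


block = TinyBlock()
# rebind forward on this one instance, leaving the class untouched
block.forward = MethodType(get_custom_forward(scale=2.0), block)
out = block(torch.randn(2, 4))  # runs the patched forward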
@@ -78,6 +135,16 @@ class DeepseekPolicy(Policy):
                 ),
             ],
         )
+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key="DeepseekModel",
+            )
 
         if self.shard_config.ep_group:
             # expert parallel
@@ -105,10 +172,12 @@ class DeepseekPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="input_layernorm",
                     target_module=FusedRMSNorm,
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
                 SubModuleReplacementDescription(
                     suffix="post_attention_layernorm",
                     target_module=FusedRMSNorm,
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
             ],
             policy=policy,
@@ -119,6 +188,7 @@ class DeepseekPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="norm",
                 target_module=FusedRMSNorm,
+                kwargs={"sp_partial_derived": sp_partial_derived},
             ),
             policy=policy,
             target_key="DeepseekModel",
@@ -4,6 +4,7 @@ from .blip2 import *
 from .bloom import *
 from .chatglm2 import *
 from .command import *
+from .deepseek import *
 from .falcon import *
 from .gpt import *
 from .gptj import *
tests/kit/model_zoo/transformers/deepseek.py (new file, 84 lines added)

# modified from tests/kit/model_zoo/transformers/mistral.py
import torch
import transformers
from transformers import AutoConfig

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence Mixtral
# ===============================


def data_gen():
    # Generated from following code snippet
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_lm():
    # LM data gen
    # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


def data_gen_for_sequence_classification():
    # sequence classification data gen
    data = data_gen()
    data["labels"] = torch.tensor([1], dtype=torch.int64)
    return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_mixtral_model = lambda x: x[0].mean()
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()


def init_deepseek():

    config = AutoConfig.from_pretrained(
        "deepseek-ai/deepseek-moe-16b-base",
        hidden_size=32,
        intermediate_size=32,
        moe_intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=8,
        num_key_value_heads=8,
        # vocab_size=2200,
        first_k_dense_replace=1,
        attn_implementation="flash_attention_2",
        torch_dtype="float16",
        n_routed_experts=8,
        trust_remote_code=True,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
    print(config)
    model = transformers.AutoModel.from_config(config, trust_remote_code=True)

    return model


model_zoo.register(
    name="transformers_deepseek",
    model_fn=init_deepseek,
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_mixtral_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
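A hedged usage sketch for the registration above: fetching the entry back out of the model zoo the way the Shardformer tests below do. Instantiating the model itself needs network access (the config is pulled with trust_remote_code=True) and a FlashAttention-capable GPU because of attn_implementation="flash_attention_2", so only the registry lookup and the data generator are exercised here; the ColossalAI repo root is assumed to be on PYTHONPATH.

from tests.kit.model_zoo import model_zoo

sub_zoo = model_zoo.get_sub_registry("transformers_deepseek")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    batch = data_gen_fn()
    print(name, {k: tuple(v.shape) for k, v in batch.items()})
    # model = model_fn()  # requires network access and a flash-attn-capable GPU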
@@ -36,8 +36,8 @@ CHECKED_CONFIG = [  # FOR_WORLD=8
     [
         # (2, 1, 2, 1, 1), # TODO debug deepseek pp
         # (2, 1, 2, 2, 1), # TODO debug deepseek pp
-        (2, 1, 1, 2, 1),
-        # (2, 1, 1, 1, 2), # TODO support deepseek sp
+        # (2, 1, 1, 2, 1),
+        (2, 1, 1, 1, 2),
         # (2, 1, 4, 1, 1), # TODO debug deepseek pp
         # (4, 1, 2, 1, 1), # TODO debug deepseek pp
     ],
@@ -69,14 +69,22 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     booster = Booster(plugin=plugin)
 
     assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
-    config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
-    config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
-    config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
-    config.num_hidden_layers = 2
-    config.num_attention_heads = NUM_HEADS
-    config.num_key_value_heads = NUM_HEADS
-    config.n_routed_experts = NUM_EXPERTS
-    config.num_experts_per_tok = TOP_K
+    # config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
+    config = AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-moe-16b-base",
+        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        num_hidden_layers=2,
+        num_attention_heads=NUM_HEADS,
+        num_key_value_heads=NUM_HEADS,
+        first_k_dense_replace=1,
+        attn_implementation="flash_attention_2",
+        torch_dtype="float16",
+        n_routed_experts=NUM_EXPERTS,
+        num_experts_per_tok=TOP_K,
+        trust_remote_code=True,
+    )
 
     # init model with the same seed
     seed_all(10086)
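The hunk above switches from mutating attributes on a downloaded config to passing the overrides directly to AutoConfig.from_pretrained, and adds the MoE / flash-attention fields the SP test needs. The two styles are interchangeable for plain config fields; a small illustration with the gpt2 config (chosen only because it needs no trust_remote_code; fetching it still requires network access):

from transformers import AutoConfig

# style being replaced: load, then patch attributes one by one
cfg_a = AutoConfig.from_pretrained("gpt2")
cfg_a.n_layer = 2
cfg_a.n_head = 4

# style used in the new test: pass the overrides as keyword arguments
cfg_b = AutoConfig.from_pretrained("gpt2", n_layer=2, n_head=4)

assert cfg_a.n_layer == cfg_b.n_layer and cfg_a.n_head == cfg_b.n_head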
tests/test_shardformer/test_model/test_shard_deepseek_ghz.py (new file, 231 lines added)

# modified from test_shard_mistral.py
import os

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,
)

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    # TODO: SGD failed for full dp
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
    )

    org_model = org_model.to(torch.float16)
    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )
    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
        check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)

    # unwrap model
    mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
    shard_mixtral_model = unwrap_model(sharded_model, "DeepseekModel", "model")

    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]

    name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        # booster.plugin.zero_stage in [1, 2]
        booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        rank = dist.get_rank()
        for n, p in shard_mixtral_model.named_parameters():
            zero_grad = sharded_optimizer.get_param_grad(p)
            if name_to_p[n].grad is None:
                name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
                continue
            assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 5e-5, 1e-4
        else:
            atol, rtol = 5e-3, 5e-3
        row_layer_grads = get_grad_tensors_for_check(
            mixtral_model,
            shard_mixtral_model,
            row_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=0,
            verbose=False,
        )
        col_layer_grads = get_grad_tensors_for_check(
            mixtral_model,
            shard_mixtral_model,
            col_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)

    # check grads
    check_all_grad_tensors(grads_to_check)

    for n, p in shard_mixtral_model.named_parameters():
        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    for n, p in shard_mixtral_model.named_parameters():
        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 2e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        try:
            check_weight(
                mixtral_model,
                shard_mixtral_model,
                col_layer_for_check,
                tp_group,
                atol=atol,
                rtol=rtol,
                dim=1,
                verbose=False,
            )
        except Exception as e:
            rank = dist.get_rank()
            print(f"{rank=}, Failed config: {test_config}")
            raise e

    torch.cuda.empty_cache()


@parameterize(
    "test_config",
    [
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 0,
        #     "overlap_communication": False,
        #     "precision": "fp16",
        # },  # [dp(4)] + [moe_dp(4)]
        # {
        #     "tp_size": 1,
        #     "pp_size": 2,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 1,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(2) + pp(2)] + [moe_pp(2)]
        # {
        #     "tp_size": 1,
        #     "pp_size": 2,
        #     "ep_size": 2,
        #     "num_microbatches": 2,
        #     "zero_stage": 1,
        #     "overlap_communication": False,
        #     "precision": "fp16",
        #     "initial_scale": 1,
        #     "find_unused_parameters": True,
        # },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
        {  # Ulysess + Flash attention
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "ep_size": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "zero_stage": 1,
            "overlap_communication": False,
            "precision": "fp16",
            "initial_scale": 1,
            "find_unused_parameters": True,
        },
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "ep_size": 2,
        #     "zero_stage": 0,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(4)] + [ep(2) + moe_tp(2)]
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "ep_size": 4,
        #     "overlap_communication": False,
        #     "zero_stage": 0,
        #     "precision": "fp32"
        # },  # full dp for non-moe and full ep for moe
    ],
)
def run_deepseek_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


def check_deepseek(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_deepseek_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mixtral():
    spawn(check_deepseek, 4)


if __name__ == "__main__":
    test_mixtral()