diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index 854bbf4f2..468b890ab 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -1,12 +1,18 @@ -from typing import List, Optional +import warnings +from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import is_flash_attn_2_available, logging from colossalai.lazy import LazyInitContext @@ -18,6 +24,11 @@ from colossalai.moe._operation import ( all_to_all_uneven, ) from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard.utils import set_tensors_to_none @@ -362,7 +373,14 @@ class DeepseekPipelineForwards: next_cache = next_decoder_cache if use_cache else None if stage_manager.is_last_stage(): - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) # always return dict for imediate stage return { "hidden_states": hidden_states, @@ -479,3 +497,276 @@ class DeepseekPipelineForwards: hidden_states = outputs.get("hidden_states") out["hidden_states"] = hidden_states return out + + +def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + # DeepseekFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + rank = dist.get_rank() + print(f"{rank=}, hidden states:{hidden_states.shape}") + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + rank = dist.get_rank() + print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}") + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}") + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}") + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0 + ) + print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}") + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + print( + f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}" + ) + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}") + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + return forward + + +def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." + ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 04d1dcd41..1e44aba4e 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -7,8 +7,14 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D from colossalai.shardformer.layer.linear import Linear1D_Row -from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE +from colossalai.shardformer.modeling.deepseek import ( + DeepseekPipelineForwards, + EPDeepseekMoE, + get_deepseek_flash_attention_forward, + get_deepseek_flash_attention_model_forward, +) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"] @@ -19,6 +25,13 @@ class DeepseekPolicy(Policy): pass def preprocess(self): + self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation + """ + Because transformers library's bug for AutoModel/AutoConfig, who pop “attn_implement” twice from modeling_utils.py and configuration_utils.py. + This bug causes attn_cls to be set to sdpa. Here we assign it to "flash_attention_2". + """ + # self.origin_attn_implement = "flash_attention_2" if self.shard_config.enable_tensor_parallelism: # Resize embedding vocab_size = self.model.config.vocab_size @@ -31,17 +44,61 @@ class DeepseekPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - policy = {} + ATTN_IMPLEMENTATION = { + "eager": "DeepseekAttention", + "flash_attention_2": "DeepseekFlashAttention2", + "sdpa": "DeepseekSdpaAttention", + } + policy = {} + print(f"{self.origin_attn_implement=}") + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) if self.shard_config.enable_sequence_parallelism: if self.pipeline_stage_manager is not None: # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism # if both are enabled, one of them will be ignored raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.") - raise NotImplementedError( - "Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + print(f"{attn_cls=}") + self.append_or_create_method_replacement( + description={ + "forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, ) - + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_deepseek_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key="DeepseekModel", + ) + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding if self.shard_config.enable_tensor_parallelism: # tensor parallelism for non-moe params assert ( @@ -78,6 +135,16 @@ class DeepseekPolicy(Policy): ), ], ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key="DeepseekModel", + ) if self.shard_config.ep_group: # expert parallel @@ -105,10 +172,12 @@ class DeepseekPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -119,6 +188,7 @@ class DeepseekPolicy(Policy): description=SubModuleReplacementDescription( suffix="norm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key="DeepseekModel", diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index ac5184065..4adc38619 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -4,6 +4,7 @@ from .blip2 import * from .bloom import * from .chatglm2 import * from .command import * +from .deepseek import * from .falcon import * from .gpt import * from .gptj import * diff --git a/tests/kit/model_zoo/transformers/deepseek.py b/tests/kit/model_zoo/transformers/deepseek.py new file mode 100644 index 000000000..b8b446b57 --- /dev/null +++ b/tests/kit/model_zoo/transformers/deepseek.py @@ -0,0 +1,84 @@ +# modified from tests/kit/model_zoo/transformers/mistral.py +import torch +import transformers +from transformers import AutoConfig + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Mixtral +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoModelForCausalLM, AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1") + # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement) + # tokenized_input = tokenizer([input], return_tensors="pt") + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_mixtral_model = lambda x: x[0].mean() +loss_fn = lambda x: x.loss +loss_fn_for_seq_classification = lambda output: output.logits.mean() + + +def init_deepseek(): + + config = AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + hidden_size=32, + intermediate_size=32, + moe_intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=8, + # vocab_size=2200, + first_k_dense_replace=1, + attn_implementation="flash_attention_2", + torch_dtype="float16", + n_routed_experts=8, + trust_remote_code=True, + ) + + if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + print(config) + model = transformers.AutoModel.from_config(config, trust_remote_code=True) + + return model + + +model_zoo.register( + name="transformers_deepseek", + model_fn=init_deepseek, + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_mixtral_model, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py index c301777f2..6e8ef2da3 100644 --- a/tests/test_shardformer/test_model/test_shard_deepseek.py +++ b/tests/test_shardformer/test_model/test_shard_deepseek.py @@ -36,8 +36,8 @@ CHECKED_CONFIG = [ # FOR_WORLD=8 [ # (2, 1, 2, 1, 1), # TODO debug deepseek pp # (2, 1, 2, 2, 1), # TODO debug deepseek pp - (2, 1, 1, 2, 1), - # (2, 1, 1, 1, 2), # TODO support deepseek sp + # (2, 1, 1, 2, 1), + (2, 1, 1, 1, 2), # (2, 1, 4, 1, 1), # TODO debug deepseek pp # (4, 1, 2, 1, 1), # TODO debug deepseek pp ], @@ -69,14 +69,22 @@ def run_zero_with_original_model(config: Tuple[int, ...]): booster = Booster(plugin=plugin) assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS" - config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True) - config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS - config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2 - config.num_hidden_layers = 2 - config.num_attention_heads = NUM_HEADS - config.num_key_value_heads = NUM_HEADS - config.n_routed_experts = NUM_EXPERTS - config.num_experts_per_tok = TOP_K + # config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True) + config = AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=2, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + first_k_dense_replace=1, + attn_implementation="flash_attention_2", + torch_dtype="float16", + n_routed_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, + trust_remote_code=True, + ) # init model with the same seed seed_all(10086) diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py new file mode 100644 index 000000000..fdca11005 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py @@ -0,0 +1,231 @@ +# modified from test_shard_mistral.py +import os + +import pytest +import torch +import torch.distributed as dist +from torch.testing import assert_close + +import colossalai +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + # TODO: SGD failed for full dp + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD + ) + + org_model = org_model.to(torch.float16) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol) + + # unwrap model + mixtral_model = unwrap_model(org_model, "DeepseekModel", "model") + shard_mixtral_model = unwrap_model(sharded_model, "DeepseekModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + + name_to_p = {n: p for n, p in mixtral_model.named_parameters()} + # Check the grad when using ZeRO-1 and ZeRO-2 + if ( + # booster.plugin.zero_stage in [1, 2] + booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" + ): + rank = dist.get_rank() + for n, p in shard_mixtral_model.named_parameters(): + zero_grad = sharded_optimizer.get_param_grad(p) + if name_to_p[n].grad is None: + name_to_p[n].grad = torch.zeros_like(name_to_p[n].data) + continue + assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False) + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 5e-5, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + mixtral_model, + shard_mixtral_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) + col_layer_grads = get_grad_tensors_for_check( + mixtral_model, + shard_mixtral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # check grads + check_all_grad_tensors(grads_to_check) + + for n, p in shard_mixtral_model.named_parameters(): + assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + for n, p in shard_mixtral_model.named_parameters(): + assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 2e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + try: + check_weight( + mixtral_model, + shard_mixtral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + except Exception as e: + rank = dist.get_rank() + print(f"{rank=}, Failed config: {test_config}") + raise e + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + # { + # "tp_size": 1, + # "pp_size": 1, + # "num_microbatches": 2, + # "ep_size": 2, + # "zero_stage": 0, + # "overlap_communication": False, + # "precision": "fp16", + # }, # [dp(4)] + [moe_dp(4)] + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "ep_size": 2, + # "zero_stage": 1, + # "overlap_communication": False, + # "precision": "fp32", + # }, # [dp(2) + pp(2)] + [moe_pp(2)] + # { + # "tp_size": 1, + # "pp_size": 2, + # "ep_size": 2, + # "num_microbatches": 2, + # "zero_stage": 1, + # "overlap_communication": False, + # "precision": "fp16", + # "initial_scale": 1, + # "find_unused_parameters": True, + # }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "ep_size": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "zero_stage": 1, + "overlap_communication": False, + "precision": "fp16", + "initial_scale": 1, + "find_unused_parameters": True, + }, + # { + # "tp_size": 1, + # "pp_size": 1, + # "ep_size": 2, + # "zero_stage": 0, + # "overlap_communication": False, + # "precision": "fp32", + # }, # [dp(4)] + [ep(2) + moe_tp(2)] + # { + # "tp_size": 1, + # "pp_size": 1, + # "ep_size": 4, + # "overlap_communication": False, + # "zero_stage": 0, + # "precision": "fp32" + # }, # full dp for non-moe and full ep for moe + ], +) +def run_deepseek_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_deepseek(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_deepseek_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_mixtral(): + spawn(check_deepseek, 4) + + +if __name__ == "__main__": + test_mixtral()