mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 11:31:58 +00:00
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins * [example] add mixtral benchmark * [moe] refine assertion and check * [moe] fix mixtral & add more tests * [moe] consider checking dp * sp group and moe_dp_group * [mixtral] remove gate tp & add more tests * [deepseek] fix tp & sp for deepseek * [mixtral] minor fix * [deepseek] add deepseek benchmark
This commit is contained in:
parent
8fd25d6e09
commit
c54c4fcd15
@ -64,13 +64,18 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||||||
forced_dtype: Optional[torch.dtype] = None,
|
forced_dtype: Optional[torch.dtype] = None,
|
||||||
overlap_allgather: bool = False,
|
overlap_allgather: bool = False,
|
||||||
):
|
):
|
||||||
pg_param_list = {
|
if dp_process_group is moe_dp_group:
|
||||||
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
|
pg_param_list = {
|
||||||
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
|
dp_process_group: list(model.parameters()),
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
pg_param_list = {
|
||||||
|
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
|
||||||
|
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
|
||||||
|
}
|
||||||
|
|
||||||
if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
|
if len(pg_param_list[moe_dp_group]) == 0:
|
||||||
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
|
raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead")
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model=model,
|
model=model,
|
||||||
@ -407,6 +412,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||||||
and self.enable_sequence_parallelism
|
and self.enable_sequence_parallelism
|
||||||
and self.sequence_parallelism_mode == "all_to_all"
|
and self.sequence_parallelism_mode == "all_to_all"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# sync gradients across DP * SP ranks
|
||||||
|
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
|
||||||
|
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
|
||||||
|
else:
|
||||||
|
dp_group = self.dp_group
|
||||||
|
|
||||||
if use_ddp:
|
if use_ddp:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
|
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
|
||||||
@ -414,17 +426,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||||||
)
|
)
|
||||||
self.ddp_config["find_unused_parameters"] = True
|
self.ddp_config["find_unused_parameters"] = True
|
||||||
|
|
||||||
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
|
if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
|
f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
|
||||||
)
|
)
|
||||||
|
|
||||||
# sync gradients across DP * SP ranks
|
|
||||||
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
|
|
||||||
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
|
|
||||||
else:
|
|
||||||
dp_group = self.dp_group
|
|
||||||
|
|
||||||
model = HybridParallelModule(
|
model = HybridParallelModule(
|
||||||
module=model,
|
module=model,
|
||||||
precision=self.precision,
|
precision=self.precision,
|
||||||
@ -466,6 +472,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||||||
tp_process_group=self.tp_group,
|
tp_process_group=self.tp_group,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
is_zero = True
|
||||||
if self.dp_size <= 1:
|
if self.dp_size <= 1:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
|
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
|
||||||
|
@ -308,7 +308,7 @@ class EPGradScalerIn(torch.autograd.Function):
|
|||||||
assert len(grad_outputs) == 1
|
assert len(grad_outputs) == 1
|
||||||
grad = grad_outputs[0]
|
grad = grad_outputs[0]
|
||||||
if ctx.ep_size != 1:
|
if ctx.ep_size != 1:
|
||||||
grad = grad * ctx.ep_size
|
grad.mul_(ctx.ep_size)
|
||||||
return grad, None
|
return grad, None
|
||||||
|
|
||||||
|
|
||||||
@ -328,7 +328,7 @@ class EPGradScalerOut(torch.autograd.Function):
|
|||||||
assert len(grad_outputs) == 1
|
assert len(grad_outputs) == 1
|
||||||
grad = grad_outputs[0]
|
grad = grad_outputs[0]
|
||||||
if ctx.ep_size != 1:
|
if ctx.ep_size != 1:
|
||||||
grad = grad / ctx.ep_size
|
grad.div_(ctx.ep_size)
|
||||||
return grad, None
|
return grad, None
|
||||||
|
|
||||||
|
|
||||||
@ -449,7 +449,4 @@ def all_to_all_uneven(
|
|||||||
overlap: bool = False,
|
overlap: bool = False,
|
||||||
fp8_communication: bool = False,
|
fp8_communication: bool = False,
|
||||||
):
|
):
|
||||||
assert (
|
|
||||||
inputs.requires_grad
|
|
||||||
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
|
|
||||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
|
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
|
||||||
|
@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.functional as F
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.cache_utils import Cache, DynamicCache
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
@ -28,11 +28,13 @@ from colossalai.quantization.fp8 import all_reduce_fp8
|
|||||||
from colossalai.shardformer.layer._operation import (
|
from colossalai.shardformer.layer._operation import (
|
||||||
all_to_all_comm,
|
all_to_all_comm,
|
||||||
gather_forward_split_backward,
|
gather_forward_split_backward,
|
||||||
|
linear_with_async_comm,
|
||||||
split_forward_gather_backward,
|
split_forward_gather_backward,
|
||||||
)
|
)
|
||||||
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
|
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||||
|
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
|
||||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +60,7 @@ class AddAuxiliaryLoss(torch.autograd.Function):
|
|||||||
return grad_output, grad_loss
|
return grad_output, grad_loss
|
||||||
|
|
||||||
|
|
||||||
class EPDeepseekMoE(nn.Module):
|
class EPDeepseekMoE(ParallelModule):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
|
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
|
||||||
|
|
||||||
@ -214,6 +216,79 @@ class EPDeepseekMoE(nn.Module):
|
|||||||
return output_hidden_states
|
return output_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekMoEGate_Col(ParallelModule):
|
||||||
|
def parallel_linear(self, hidden_states):
|
||||||
|
assert (
|
||||||
|
hidden_states.shape[-1] == self.weight.shape[-1]
|
||||||
|
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||||
|
hidden_states.shape, self.weight.shape, self.weight.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
output = linear_with_async_comm(
|
||||||
|
hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
|
||||||
|
# All-gather across the partitions.
|
||||||
|
output = gather_forward_split_backward(
|
||||||
|
output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
bsz, seq_len, h = hidden_states.shape
|
||||||
|
### compute gating score
|
||||||
|
hidden_states = hidden_states.view(-1, h)
|
||||||
|
logits = self.parallel_linear(hidden_states)
|
||||||
|
if self.scoring_func == "softmax":
|
||||||
|
scores = logits.softmax(dim=-1)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
|
||||||
|
|
||||||
|
### select top-k experts
|
||||||
|
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||||
|
|
||||||
|
### norm gate to sum 1
|
||||||
|
if self.top_k > 1 and self.norm_topk_prob:
|
||||||
|
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||||
|
topk_weight = topk_weight / denominator
|
||||||
|
|
||||||
|
### expert-level computation auxiliary loss
|
||||||
|
if self.training and self.alpha > 0.0:
|
||||||
|
scores_for_aux = scores
|
||||||
|
aux_topk = self.top_k
|
||||||
|
# always compute aux loss based on the naive greedy topk method
|
||||||
|
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
||||||
|
if self.seq_aux:
|
||||||
|
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
||||||
|
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
||||||
|
ce.scatter_add_(
|
||||||
|
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
|
||||||
|
).div_(seq_len * aux_topk / self.n_routed_experts)
|
||||||
|
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
|
||||||
|
else:
|
||||||
|
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
|
||||||
|
ce = mask_ce.float().mean(0)
|
||||||
|
Pi = scores_for_aux.mean(0)
|
||||||
|
fi = ce * self.n_routed_experts
|
||||||
|
aux_loss = (Pi * fi).sum() * self.alpha
|
||||||
|
else:
|
||||||
|
aux_loss = None
|
||||||
|
|
||||||
|
return topk_idx, topk_weight, aux_loss
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_native_module(
|
||||||
|
module, process_group: ProcessGroup, config, gather_output, fp8_communication
|
||||||
|
) -> "DeepseekMoEGate_Col":
|
||||||
|
LazyInitContext.materialize(module)
|
||||||
|
module.process_group = process_group
|
||||||
|
module.fp8_communication = fp8_communication
|
||||||
|
sharded_weight = shard_rowwise(module.weight.data, process_group)
|
||||||
|
sharded_tensor_to_existing_param(sharded_weight, module.weight)
|
||||||
|
module.__class__ = DeepseekMoEGate_Col
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
class DeepseekPipelineForwards:
|
class DeepseekPipelineForwards:
|
||||||
"""
|
"""
|
||||||
This class serves as a micro library for forward function substitution of Llama models
|
This class serves as a micro library for forward function substitution of Llama models
|
||||||
|
@ -36,7 +36,7 @@ from colossalai.shardformer.layer._operation import (
|
|||||||
gather_forward_split_backward,
|
gather_forward_split_backward,
|
||||||
split_forward_gather_backward,
|
split_forward_gather_backward,
|
||||||
)
|
)
|
||||||
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
|
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
||||||
@ -49,7 +49,7 @@ if is_flash_attn_2_available():
|
|||||||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||||
|
|
||||||
|
|
||||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
class EPMixtralSparseMoeBlock(ParallelModule):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
|
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
|||||||
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
||||||
from colossalai.shardformer.layer.linear import Linear1D_Row
|
from colossalai.shardformer.layer.linear import Linear1D_Row
|
||||||
from colossalai.shardformer.modeling.deepseek import (
|
from colossalai.shardformer.modeling.deepseek import (
|
||||||
|
DeepseekMoEGate_Col,
|
||||||
DeepseekPipelineForwards,
|
DeepseekPipelineForwards,
|
||||||
EPDeepseekMoE,
|
EPDeepseekMoE,
|
||||||
get_deepseek_flash_attention_forward,
|
get_deepseek_flash_attention_forward,
|
||||||
@ -56,16 +57,24 @@ class DeepseekPolicy(Policy):
|
|||||||
sp_size = self.shard_config.sequence_parallel_size or None
|
sp_size = self.shard_config.sequence_parallel_size or None
|
||||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||||
|
tp_size = self.shard_config.tensor_parallel_size
|
||||||
|
|
||||||
|
# modified for both SP and TP
|
||||||
|
num_q_heads = self.model.config.num_attention_heads
|
||||||
|
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
|
||||||
if sp_mode == "all_to_all":
|
if sp_mode == "all_to_all":
|
||||||
|
num_q_heads //= sp_size
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
"num_heads": self.model.config.num_attention_heads // sp_size,
|
"num_heads": num_q_heads,
|
||||||
}
|
}
|
||||||
if getattr(self.model.config, "num_key_value_heads", False):
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
num_kv_heads //= sp_size
|
||||||
|
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
|
||||||
|
|
||||||
policy[attn_cls] = ModulePolicyDescription(
|
policy[attn_cls] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.shard_config.enable_sequence_parallelism:
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
if self.pipeline_stage_manager is not None:
|
if self.pipeline_stage_manager is not None:
|
||||||
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
|
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
|
||||||
@ -97,6 +106,7 @@ class DeepseekPolicy(Policy):
|
|||||||
else:
|
else:
|
||||||
if self.tie_weight:
|
if self.tie_weight:
|
||||||
embedding_cls = PaddingEmbedding
|
embedding_cls = PaddingEmbedding
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# tensor parallelism for non-moe params
|
# tensor parallelism for non-moe params
|
||||||
assert (
|
assert (
|
||||||
@ -107,10 +117,15 @@ class DeepseekPolicy(Policy):
|
|||||||
), f"The number of key_value heads must be divisible by tensor parallel size."
|
), f"The number of key_value heads must be divisible by tensor parallel size."
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
|
||||||
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
|
|
||||||
// self.shard_config.tensor_parallel_size,
|
|
||||||
}
|
}
|
||||||
|
num_q_heads //= tp_size
|
||||||
|
decoder_attribute_replacement = {
|
||||||
|
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
|
"self_attn.num_heads": num_q_heads,
|
||||||
|
}
|
||||||
|
if num_kv_heads:
|
||||||
|
num_kv_heads //= tp_size
|
||||||
|
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
|
||||||
|
|
||||||
policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
|
policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
@ -135,8 +150,19 @@ class DeepseekPolicy(Policy):
|
|||||||
target_module=Linear1D_Row,
|
target_module=Linear1D_Row,
|
||||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||||
),
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.gate",
|
||||||
|
target_module=DeepseekMoEGate_Col,
|
||||||
|
kwargs={
|
||||||
|
"gather_output": True,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"config": self.model.config,
|
||||||
|
},
|
||||||
|
ignore_if_not_exist=True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if embedding_cls is not None:
|
if embedding_cls is not None:
|
||||||
self.append_or_create_submodule_replacement(
|
self.append_or_create_submodule_replacement(
|
||||||
description=SubModuleReplacementDescription(
|
description=SubModuleReplacementDescription(
|
||||||
|
@ -51,12 +51,20 @@ class MixtralPolicy(Policy):
|
|||||||
sp_size = self.shard_config.sequence_parallel_size or None
|
sp_size = self.shard_config.sequence_parallel_size or None
|
||||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||||
|
tp_size = self.shard_config.tensor_parallel_size
|
||||||
|
|
||||||
|
# modified for both SP and TP
|
||||||
|
num_q_heads = self.model.config.num_attention_heads
|
||||||
|
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
|
||||||
|
|
||||||
if sp_mode == "all_to_all":
|
if sp_mode == "all_to_all":
|
||||||
|
num_q_heads //= sp_size
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
"num_heads": self.model.config.num_attention_heads // sp_size,
|
"num_heads": num_q_heads,
|
||||||
}
|
}
|
||||||
if getattr(self.model.config, "num_key_value_heads", False):
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
num_kv_heads //= sp_size
|
||||||
|
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
|
||||||
|
|
||||||
policy[attn_cls] = ModulePolicyDescription(
|
policy[attn_cls] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
@ -101,12 +109,14 @@ class MixtralPolicy(Policy):
|
|||||||
assert (
|
assert (
|
||||||
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
||||||
), f"The number of key_value heads must be divisible by tensor parallel size."
|
), f"The number of key_value heads must be divisible by tensor parallel size."
|
||||||
|
num_q_heads //= tp_size
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
"self_attn.num_heads": num_q_heads,
|
||||||
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
|
|
||||||
// self.shard_config.tensor_parallel_size,
|
|
||||||
}
|
}
|
||||||
|
if num_kv_heads:
|
||||||
|
num_kv_heads //= tp_size
|
||||||
|
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
|
||||||
|
|
||||||
policy[MixtralDecoderLayer] = ModulePolicyDescription(
|
policy[MixtralDecoderLayer] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
@ -131,7 +141,7 @@ class MixtralPolicy(Policy):
|
|||||||
target_module=Linear1D_Row,
|
target_module=Linear1D_Row,
|
||||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription( # or replicate?
|
SubModuleReplacementDescription(
|
||||||
suffix="block_sparse_moe.gate",
|
suffix="block_sparse_moe.gate",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
|
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
|
||||||
|
271
examples/language/deepseek/benchmark.py
Normal file
271
examples/language/deepseek/benchmark.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
# modified from mixtral benchmark
|
||||||
|
import argparse
|
||||||
|
import resource
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from data_utils import RandomDataset
|
||||||
|
from model_utils import format_numel_str, get_model_numel
|
||||||
|
from performance_evaluator import PerformanceEvaluator, get_profile_context
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import MoeHybridParallelPlugin
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
# ==============================
|
||||||
|
# Constants
|
||||||
|
# ==============================
|
||||||
|
|
||||||
|
# We have lots of llamas for your choice!
|
||||||
|
MODEL_CONFIGS = {
|
||||||
|
"100m": lambda: AutoConfig.from_pretrained(
|
||||||
|
"deepseek-ai/deepseek-moe-16b-base",
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
num_hidden_layers=1,
|
||||||
|
num_attention_heads=32,
|
||||||
|
intermediate_size=512,
|
||||||
|
moe_intermediate_size=128,
|
||||||
|
hidden_size=512,
|
||||||
|
n_routed_experts=8,
|
||||||
|
n_shared_experts=4,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
first_k_dense_replace=0,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
trust_remote_code=True,
|
||||||
|
),
|
||||||
|
"7b": lambda: AutoConfig.from_pretrained(
|
||||||
|
"deepseek-ai/deepseek-moe-16b-base",
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
num_hidden_layers=13,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
trust_remote_code=True,
|
||||||
|
),
|
||||||
|
"14b": lambda: AutoConfig.from_pretrained(
|
||||||
|
"deepseek-ai/deepseek-moe-16b-base",
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
num_hidden_layers=26,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
trust_remote_code=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# ==============================
|
||||||
|
# Parse Arguments
|
||||||
|
# ==============================
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration")
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--plugin",
|
||||||
|
choices=["3d"],
|
||||||
|
default="3d",
|
||||||
|
help="Choose which plugin to use",
|
||||||
|
)
|
||||||
|
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
|
||||||
|
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
|
||||||
|
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
|
||||||
|
parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
|
||||||
|
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
|
||||||
|
parser.add_argument(
|
||||||
|
"-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
|
||||||
|
)
|
||||||
|
parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
|
||||||
|
parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
|
||||||
|
parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
|
||||||
|
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
|
||||||
|
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
|
||||||
|
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
|
||||||
|
parser.add_argument("--ep", type=int, default=1, help="Expert parallel size")
|
||||||
|
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
|
||||||
|
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
|
||||||
|
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
|
||||||
|
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
|
||||||
|
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
|
||||||
|
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
||||||
|
|
||||||
|
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
||||||
|
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||||
|
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nsys",
|
||||||
|
action="store_true",
|
||||||
|
help="Use nsys for profiling. \
|
||||||
|
You should put something like this before colossalai launch: \
|
||||||
|
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
|
||||||
|
)
|
||||||
|
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
|
||||||
|
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||||
|
parser.add_argument("--no_cache", action="store_true")
|
||||||
|
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||||
|
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
|
||||||
|
parser.add_argument("--overlap_allgather", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--sp_mode",
|
||||||
|
default="all_to_all",
|
||||||
|
choices=["all_to_all"],
|
||||||
|
help="Sequence parallelism mode",
|
||||||
|
)
|
||||||
|
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
colossalai.launch_from_torch()
|
||||||
|
coordinator = DistCoordinator()
|
||||||
|
|
||||||
|
# ckpt config for LLaMA3-70B on 64 H100 GPUs
|
||||||
|
hybrid_kwargs = (
|
||||||
|
{
|
||||||
|
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
||||||
|
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
||||||
|
),
|
||||||
|
"num_layers_per_stage": [19, 20, 20, 21],
|
||||||
|
"pp_style": "interleaved",
|
||||||
|
}
|
||||||
|
if args.custom_ckpt
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Booster
|
||||||
|
# ==============================
|
||||||
|
if args.plugin == "3d":
|
||||||
|
plugin = MoeHybridParallelPlugin(
|
||||||
|
ep_size=args.ep,
|
||||||
|
tp_size=args.tp,
|
||||||
|
pp_size=args.pp,
|
||||||
|
pp_style=args.pp_style,
|
||||||
|
num_model_chunks=args.n_chunks,
|
||||||
|
zero_stage=args.zero,
|
||||||
|
sp_size=args.sp,
|
||||||
|
sequence_parallelism_mode=args.sp_mode,
|
||||||
|
enable_sequence_parallelism=args.sp > 1,
|
||||||
|
enable_fused_normalization=torch.cuda.is_available(),
|
||||||
|
enable_flash_attention=args.xformers,
|
||||||
|
microbatch_size=args.mbs,
|
||||||
|
precision="bf16",
|
||||||
|
enable_metadata_cache=not args.no_cache,
|
||||||
|
overlap_allgather=args.overlap_allgather,
|
||||||
|
use_fp8=args.use_fp8,
|
||||||
|
fp8_communication=args.use_fp8_comm,
|
||||||
|
**hybrid_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
|
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Dataset and Dataloader
|
||||||
|
# ==============================
|
||||||
|
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
|
||||||
|
|
||||||
|
config = MODEL_CONFIGS[args.config]()
|
||||||
|
|
||||||
|
torch.cuda.manual_seed(42)
|
||||||
|
|
||||||
|
dataset = RandomDataset(
|
||||||
|
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
||||||
|
)
|
||||||
|
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Model and Optimizer
|
||||||
|
# ==============================
|
||||||
|
init_ctx = (
|
||||||
|
LazyInitContext(default_device=get_accelerator().get_current_device())
|
||||||
|
if isinstance(plugin, MoeHybridParallelPlugin)
|
||||||
|
else nullcontext()
|
||||||
|
)
|
||||||
|
|
||||||
|
with init_ctx:
|
||||||
|
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
|
||||||
|
|
||||||
|
if args.grad_checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
model_numel = get_model_numel(model)
|
||||||
|
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||||
|
performance_evaluator = PerformanceEvaluator(
|
||||||
|
model_numel,
|
||||||
|
model.config.num_hidden_layers,
|
||||||
|
model.config.hidden_size,
|
||||||
|
model.config.vocab_size,
|
||||||
|
args.grad_checkpoint,
|
||||||
|
args.ignore_steps,
|
||||||
|
dp_world_size=dp_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = HybridAdam(model.parameters())
|
||||||
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
|
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.float)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
with get_profile_context(
|
||||||
|
args.profile,
|
||||||
|
args.ignore_steps,
|
||||||
|
1, # avoid creating massive log files
|
||||||
|
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||||
|
nsys=args.nsys,
|
||||||
|
) as prof: # , distributed_debug_mode(10, enable=True):
|
||||||
|
if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
|
||||||
|
data_iter = iter(dataloader)
|
||||||
|
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
|
||||||
|
performance_evaluator.on_step_start(step)
|
||||||
|
outputs = booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
|
model,
|
||||||
|
criterion=lambda outputs, inputs: outputs[0],
|
||||||
|
optimizer=optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
)
|
||||||
|
loss = outputs["loss"]
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
print(f"Step {step} loss: {loss}")
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
|
||||||
|
prof.step()
|
||||||
|
print(f"rank {dist.get_rank()} step {step} passed")
|
||||||
|
else:
|
||||||
|
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
|
||||||
|
performance_evaluator.on_step_start(step)
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = outputs[0]
|
||||||
|
del outputs # free memory
|
||||||
|
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
print(f"Step {step} loss: {loss}")
|
||||||
|
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
performance_evaluator.on_step_end(**batch)
|
||||||
|
prof.step()
|
||||||
|
|
||||||
|
performance_evaluator.on_fit_end()
|
||||||
|
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
examples/language/deepseek/data_utils.py
Symbolic link
1
examples/language/deepseek/data_utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../data_utils.py
|
1
examples/language/deepseek/model_utils.py
Symbolic link
1
examples/language/deepseek/model_utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../model_utils.py
|
1
examples/language/deepseek/performance_evaluator.py
Symbolic link
1
examples/language/deepseek/performance_evaluator.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../performance_evaluator.py
|
0
examples/language/deepseek/test_ci.sh
Executable file
0
examples/language/deepseek/test_ci.sh
Executable file
@ -105,7 +105,7 @@ def main():
|
|||||||
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||||
parser.add_argument("--no_cache", action="store_true")
|
parser.add_argument("--no_cache", action="store_true")
|
||||||
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||||
parser.add_argument("--use_fp8", action="store_true")
|
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
|
||||||
parser.add_argument("--overlap_allgather", action="store_true")
|
parser.add_argument("--overlap_allgather", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sp_mode",
|
"--sp_mode",
|
||||||
@ -151,6 +151,7 @@ def main():
|
|||||||
max_prefetch=args.prefetch_num,
|
max_prefetch=args.prefetch_num,
|
||||||
enable_async_reduce=not args.disable_async_reduce,
|
enable_async_reduce=not args.disable_async_reduce,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
|
fp8_communication=args.use_fp8_comm,
|
||||||
)
|
)
|
||||||
elif args.plugin == "gemini_auto":
|
elif args.plugin == "gemini_auto":
|
||||||
plugin = GeminiPlugin(
|
plugin = GeminiPlugin(
|
||||||
@ -164,6 +165,7 @@ def main():
|
|||||||
enable_async_reduce=not args.disable_async_reduce,
|
enable_async_reduce=not args.disable_async_reduce,
|
||||||
enable_flash_attention=args.xformers,
|
enable_flash_attention=args.xformers,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
|
fp8_communication=args.use_fp8_comm,
|
||||||
)
|
)
|
||||||
elif args.plugin == "fsdp":
|
elif args.plugin == "fsdp":
|
||||||
if use_empty_init:
|
if use_empty_init:
|
||||||
@ -224,6 +226,7 @@ def main():
|
|||||||
enable_metadata_cache=not args.no_cache,
|
enable_metadata_cache=not args.no_cache,
|
||||||
overlap_allgather=args.overlap_allgather,
|
overlap_allgather=args.overlap_allgather,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
|
fp8_communication=args.use_fp8_comm,
|
||||||
**hybrid_kwargs,
|
**hybrid_kwargs,
|
||||||
)
|
)
|
||||||
elif args.plugin == "3d_cpu":
|
elif args.plugin == "3d_cpu":
|
||||||
@ -241,6 +244,7 @@ def main():
|
|||||||
precision="bf16",
|
precision="bf16",
|
||||||
overlap_p2p=args.overlap,
|
overlap_p2p=args.overlap,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
|
fp8_communication=args.use_fp8_comm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
|
259
examples/language/mixtral/benchmark.py
Normal file
259
examples/language/mixtral/benchmark.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
# modified from llama benchmark
|
||||||
|
import argparse
|
||||||
|
import resource
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from data_utils import RandomDataset
|
||||||
|
from model_utils import format_numel_str, get_model_numel
|
||||||
|
from performance_evaluator import PerformanceEvaluator, get_profile_context
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import MoeHybridParallelPlugin
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
# ==============================
|
||||||
|
# Constants
|
||||||
|
# ==============================
|
||||||
|
|
||||||
|
# We have lots of llamas for your choice!
|
||||||
|
MODEL_CONFIGS = {
|
||||||
|
"100m": MixtralConfig(
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
num_attention_heads=32,
|
||||||
|
intermediate_size=768,
|
||||||
|
hidden_size=768,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
),
|
||||||
|
"7b": MixtralConfig(
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
),
|
||||||
|
"14b": MixtralConfig(
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
num_hidden_layers=10,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# ==============================
|
||||||
|
# Parse Arguments
|
||||||
|
# ==============================
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration")
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--plugin",
|
||||||
|
choices=["3d"],
|
||||||
|
default="3d",
|
||||||
|
help="Choose which plugin to use",
|
||||||
|
)
|
||||||
|
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
|
||||||
|
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
|
||||||
|
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
|
||||||
|
parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
|
||||||
|
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
|
||||||
|
parser.add_argument(
|
||||||
|
"-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
|
||||||
|
)
|
||||||
|
parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
|
||||||
|
parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
|
||||||
|
parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
|
||||||
|
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
|
||||||
|
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
|
||||||
|
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
|
||||||
|
parser.add_argument("--ep", type=int, default=1, help="Expert parallel size")
|
||||||
|
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
|
||||||
|
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
|
||||||
|
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
|
||||||
|
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
|
||||||
|
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
|
||||||
|
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
||||||
|
|
||||||
|
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
||||||
|
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||||
|
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nsys",
|
||||||
|
action="store_true",
|
||||||
|
help="Use nsys for profiling. \
|
||||||
|
You should put something like this before colossalai launch: \
|
||||||
|
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
|
||||||
|
)
|
||||||
|
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
|
||||||
|
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||||
|
parser.add_argument("--no_cache", action="store_true")
|
||||||
|
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||||
|
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
|
||||||
|
parser.add_argument("--overlap_allgather", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--sp_mode",
|
||||||
|
default="all_to_all",
|
||||||
|
choices=["all_to_all"],
|
||||||
|
help="Sequence parallelism mode",
|
||||||
|
)
|
||||||
|
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
colossalai.launch_from_torch()
|
||||||
|
coordinator = DistCoordinator()
|
||||||
|
|
||||||
|
# ckpt config for LLaMA3-70B on 64 H100 GPUs
|
||||||
|
hybrid_kwargs = (
|
||||||
|
{
|
||||||
|
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
||||||
|
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
||||||
|
),
|
||||||
|
"num_layers_per_stage": [19, 20, 20, 21],
|
||||||
|
"pp_style": "interleaved",
|
||||||
|
}
|
||||||
|
if args.custom_ckpt
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Booster
|
||||||
|
# ==============================
|
||||||
|
if args.plugin == "3d":
|
||||||
|
plugin = MoeHybridParallelPlugin(
|
||||||
|
ep_size=args.ep,
|
||||||
|
tp_size=args.tp,
|
||||||
|
pp_size=args.pp,
|
||||||
|
pp_style=args.pp_style,
|
||||||
|
num_model_chunks=args.n_chunks,
|
||||||
|
zero_stage=args.zero,
|
||||||
|
sp_size=args.sp,
|
||||||
|
sequence_parallelism_mode=args.sp_mode,
|
||||||
|
enable_sequence_parallelism=args.sp > 1,
|
||||||
|
enable_fused_normalization=torch.cuda.is_available(),
|
||||||
|
enable_flash_attention=args.xformers,
|
||||||
|
microbatch_size=args.mbs,
|
||||||
|
precision="bf16",
|
||||||
|
enable_metadata_cache=not args.no_cache,
|
||||||
|
overlap_allgather=args.overlap_allgather,
|
||||||
|
use_fp8=args.use_fp8,
|
||||||
|
fp8_communication=args.use_fp8_comm,
|
||||||
|
**hybrid_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
|
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Dataset and Dataloader
|
||||||
|
# ==============================
|
||||||
|
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
|
||||||
|
|
||||||
|
if args.config in MODEL_CONFIGS:
|
||||||
|
config = MODEL_CONFIGS[args.config]
|
||||||
|
else:
|
||||||
|
config = MixtralConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||||
|
torch.cuda.manual_seed(42)
|
||||||
|
|
||||||
|
dataset = RandomDataset(
|
||||||
|
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
||||||
|
)
|
||||||
|
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Model and Optimizer
|
||||||
|
# ==============================
|
||||||
|
init_ctx = (
|
||||||
|
LazyInitContext(default_device=get_accelerator().get_current_device())
|
||||||
|
if isinstance(plugin, MoeHybridParallelPlugin)
|
||||||
|
else nullcontext()
|
||||||
|
)
|
||||||
|
|
||||||
|
with init_ctx:
|
||||||
|
model = MixtralForCausalLM(config=config).to(torch.bfloat16)
|
||||||
|
|
||||||
|
if args.grad_checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
model_numel = get_model_numel(model)
|
||||||
|
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||||
|
performance_evaluator = PerformanceEvaluator(
|
||||||
|
model_numel,
|
||||||
|
model.config.num_hidden_layers,
|
||||||
|
model.config.hidden_size,
|
||||||
|
model.config.vocab_size,
|
||||||
|
args.grad_checkpoint,
|
||||||
|
args.ignore_steps,
|
||||||
|
dp_world_size=dp_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = HybridAdam(model.parameters())
|
||||||
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
|
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.float)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
with get_profile_context(
|
||||||
|
args.profile,
|
||||||
|
args.ignore_steps,
|
||||||
|
1, # avoid creating massive log files
|
||||||
|
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||||
|
nsys=args.nsys,
|
||||||
|
) as prof:
|
||||||
|
if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
|
||||||
|
data_iter = iter(dataloader)
|
||||||
|
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
|
||||||
|
performance_evaluator.on_step_start(step)
|
||||||
|
outputs = booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
|
model,
|
||||||
|
criterion=lambda outputs, inputs: outputs[0],
|
||||||
|
optimizer=optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
)
|
||||||
|
loss = outputs["loss"]
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
print(f"Step {step} loss: {loss}")
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
|
||||||
|
prof.step()
|
||||||
|
else:
|
||||||
|
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
|
||||||
|
performance_evaluator.on_step_start(step)
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = outputs[0]
|
||||||
|
del outputs # free memory
|
||||||
|
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
print(f"Step {step} loss: {loss}")
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
performance_evaluator.on_step_end(**batch)
|
||||||
|
prof.step()
|
||||||
|
performance_evaluator.on_fit_end()
|
||||||
|
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
examples/language/mixtral/data_utils.py
Symbolic link
1
examples/language/mixtral/data_utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../data_utils.py
|
1
examples/language/mixtral/model_utils.py
Symbolic link
1
examples/language/mixtral/model_utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../model_utils.py
|
1
examples/language/mixtral/performance_evaluator.py
Symbolic link
1
examples/language/mixtral/performance_evaluator.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../performance_evaluator.py
|
0
examples/language/mixtral/test_ci.sh
Executable file
0
examples/language/mixtral/test_ci.sh
Executable file
@ -1,4 +1,12 @@
|
|||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from time import sleep
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
|
||||||
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
|
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
|
||||||
@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
|||||||
return torch.allclose(a, b, rtol=rtol, atol=atol)
|
return torch.allclose(a, b, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
def check_model_equal(model1, model2):
|
def check_model_equal(model1, model2, dtype):
|
||||||
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
|
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
|
||||||
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
|
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
|
||||||
assert_loose_close(p1, p2, p1.dtype)
|
assert_loose_close(p1, p2, dtype, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
|
||||||
|
if enable:
|
||||||
|
assert (
|
||||||
|
os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
|
||||||
|
), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
|
||||||
|
if funcs_to_patch is None:
|
||||||
|
funcs_to_patch = [
|
||||||
|
dist.all_reduce,
|
||||||
|
dist.all_reduce_coalesced,
|
||||||
|
dist.all_gather,
|
||||||
|
dist.all_gather_coalesced,
|
||||||
|
dist.all_gather_into_tensor,
|
||||||
|
dist.all_to_all,
|
||||||
|
dist.all_to_all_single,
|
||||||
|
dist.reduce_scatter,
|
||||||
|
]
|
||||||
|
|
||||||
|
original_funcs = {}
|
||||||
|
patched_funcs = {}
|
||||||
|
|
||||||
|
def make_patched(func):
|
||||||
|
def patched_func(*args, **kwargs):
|
||||||
|
stack = traceback.format_stack()
|
||||||
|
|
||||||
|
def format_node(node):
|
||||||
|
if isinstance(node, torch.Tensor):
|
||||||
|
return f"{node.shape}"
|
||||||
|
elif isinstance(node, list):
|
||||||
|
return f"[{', '.join([format_node(n) for n in node])}]"
|
||||||
|
|
||||||
|
return str(node)
|
||||||
|
|
||||||
|
args_str, kwargs_str = tree_map(format_node, (args, kwargs))
|
||||||
|
en = len(stack) - 1
|
||||||
|
st = max(0, en - num_stacks)
|
||||||
|
dist.barrier()
|
||||||
|
sleep(0.001 * dist.get_rank())
|
||||||
|
print(
|
||||||
|
f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
|
||||||
|
)
|
||||||
|
dist.barrier()
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return patched_func
|
||||||
|
|
||||||
|
if enable:
|
||||||
|
for func in funcs_to_patch:
|
||||||
|
original_funcs[func.__name__] = getattr(dist, func.__name__)
|
||||||
|
patched_funcs[func.__name__] = make_patched(func)
|
||||||
|
setattr(dist, func.__name__, patched_funcs[func.__name__])
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for func_name, original_func in original_funcs.items():
|
||||||
|
setattr(dist, func_name, original_func)
|
||||||
|
@ -130,7 +130,7 @@ def check_moe_checkpoint(test_config):
|
|||||||
dist.barrier()
|
dist.barrier()
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
|
saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
|
||||||
check_model_equal(orig_model, saved_model)
|
check_model_equal(orig_model, saved_model, dtype=dtype)
|
||||||
saved_model.save_pretrained(hf_model_dir)
|
saved_model.save_pretrained(hf_model_dir)
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
# check load model
|
# check load model
|
||||||
@ -138,7 +138,7 @@ def check_moe_checkpoint(test_config):
|
|||||||
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
||||||
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
|
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
|
||||||
booster.load_model(new_model, hf_model_dir)
|
booster.load_model(new_model, hf_model_dir)
|
||||||
check_model_equal(model, new_model)
|
check_model_equal(model, new_model, dtype=dtype)
|
||||||
|
|
||||||
# check save optimizer
|
# check save optimizer
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
@ -12,43 +12,25 @@ from transformers import AutoConfig, AutoModel
|
|||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster.booster import Booster
|
from colossalai.booster.booster import Booster
|
||||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||||
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
|
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
|
||||||
|
|
||||||
NUM_BATCH = 8
|
NUM_BATCH = 8
|
||||||
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
|
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
|
||||||
NUM_LAYERS = 4
|
NUM_LAYERS = 4
|
||||||
HIDDEN_SIZE_PER_HEAD = 4
|
HIDDEN_SIZE_PER_HEAD = 4
|
||||||
NUM_HEADS = 4
|
NUM_HEADS = 8
|
||||||
TOP_K = 2
|
TOP_K = 2
|
||||||
|
|
||||||
|
|
||||||
CHECKED_CONFIG = [ # FOR_WORLD=4
|
def run_deepseek_commom(config: Tuple[int, ...]):
|
||||||
(1, 4, 1, 1, 1),
|
Randomizer.reset_index()
|
||||||
(1, 1, 4, 1, 1),
|
|
||||||
(1, 1, 1, 4, 1),
|
|
||||||
(1, 1, 1, 1, 4),
|
|
||||||
(0, 1, 4, 1, 1),
|
|
||||||
(0, 1, 1, 4, 1),
|
|
||||||
(0, 1, 1, 1, 4),
|
|
||||||
(1, 2, 1, 1, 1),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize(
|
|
||||||
"config",
|
|
||||||
[
|
|
||||||
(1, 2, 2, 1, 1),
|
|
||||||
(1, 2, 1, 2, 1),
|
|
||||||
(1, 2, 1, 1, 2),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def run_zero_with_original_model(config: Tuple[int, ...]):
|
|
||||||
stage, ep_size, pp_size, tp_size, sp_size = config
|
stage, ep_size, pp_size, tp_size, sp_size = config
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
dtype, precision = torch.float16, "fp16"
|
dtype, precision = torch.bfloat16, "bf16"
|
||||||
torch.cuda.set_device(dist.get_rank())
|
torch.cuda.set_device(dist.get_rank())
|
||||||
|
|
||||||
plugin = MoeHybridParallelPlugin(
|
plugin = MoeHybridParallelPlugin(
|
||||||
@ -60,11 +42,11 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||||||
zero_stage=stage,
|
zero_stage=stage,
|
||||||
enable_sequence_parallelism=sp_size > 1,
|
enable_sequence_parallelism=sp_size > 1,
|
||||||
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
|
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
|
||||||
enable_flash_attention=sp_size > 1,
|
|
||||||
overlap_communication=False,
|
overlap_communication=False,
|
||||||
initial_scale=1,
|
initial_scale=1,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
find_unused_parameters=True,
|
find_unused_parameters=True,
|
||||||
|
enable_flash_attention=True,
|
||||||
)
|
)
|
||||||
dp_size = plugin.dp_size
|
dp_size = plugin.dp_size
|
||||||
|
|
||||||
@ -171,7 +153,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
|
saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
|
||||||
check_model_equal(torch_model, saved_model)
|
check_model_equal(torch_model, saved_model, dtype=dtype)
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
if rank == world_size - 1:
|
if rank == world_size - 1:
|
||||||
@ -180,17 +162,77 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||||||
print(f"rank {dist.get_rank()} test passed")
|
print(f"rank {dist.get_rank()} test passed")
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
@parameterize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
|
||||||
|
(0, 1, 4, 1, 1),
|
||||||
|
(0, 1, 1, 4, 1),
|
||||||
|
(0, 1, 2, 2, 1),
|
||||||
|
# zero 1
|
||||||
|
(1, 4, 1, 1, 1),
|
||||||
|
(1, 1, 4, 1, 1),
|
||||||
|
(1, 1, 1, 4, 1),
|
||||||
|
(1, 2, 1, 1, 2),
|
||||||
|
# zero 2
|
||||||
|
(2, 4, 1, 1, 1),
|
||||||
|
(2, 1, 4, 1, 1),
|
||||||
|
(2, 1, 1, 4, 1),
|
||||||
|
(2, 2, 1, 1, 2),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_deepseek_test(config: Tuple[int, ...]):
|
||||||
|
run_deepseek_commom(config)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
|
||||||
|
(0, 1, 2, 4, 1),
|
||||||
|
(0, 1, 4, 2, 1),
|
||||||
|
(0, 1, 1, 4, 1),
|
||||||
|
(0, 1, 4, 1, 1),
|
||||||
|
# zero 1:
|
||||||
|
(1, 2, 1, 1, 2),
|
||||||
|
(1, 2, 1, 4, 1),
|
||||||
|
(1, 1, 1, 2, 2),
|
||||||
|
(1, 2, 2, 2, 1),
|
||||||
|
# zero 2
|
||||||
|
(2, 2, 1, 1, 2),
|
||||||
|
(2, 2, 1, 4, 1),
|
||||||
|
(2, 1, 1, 2, 2),
|
||||||
|
(2, 2, 2, 2, 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_deepseek_3d_test(config: Tuple[int, ...]):
|
||||||
|
run_deepseek_commom(config)
|
||||||
|
|
||||||
|
|
||||||
|
def check_deepseek(rank, world_size, port):
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
run_zero_with_original_model()
|
run_deepseek_test()
|
||||||
|
|
||||||
|
|
||||||
|
def check_deepseek_3d(rank, world_size, port):
|
||||||
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
|
run_deepseek_3d_test()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [4])
|
@pytest.mark.parametrize("world_size", [4])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_deepseek(world_size):
|
def test_deepseek(world_size):
|
||||||
spawn(run_dist, world_size)
|
spawn(check_deepseek, world_size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.largedist
|
||||||
|
@pytest.mark.parametrize("world_size", [8])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_deepseek_3d(world_size):
|
||||||
|
spawn(check_deepseek_3d, world_size)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_deepseek(world_size=4)
|
test_deepseek(world_size=8)
|
||||||
|
test_deepseek_3d(world_size=8)
|
||||||
|
@ -13,42 +13,25 @@ from transformers.models.mixtral.modeling_mixtral import MixtralModel
|
|||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster.booster import Booster
|
from colossalai.booster.booster import Booster
|
||||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||||
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
|
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
|
||||||
|
|
||||||
NUM_BATCH = 8
|
NUM_BATCH = 8
|
||||||
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
|
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
|
||||||
NUM_LAYERS = 4
|
NUM_LAYERS = 4
|
||||||
HIDDEN_SIZE_PER_HEAD = 4
|
HIDDEN_SIZE_PER_HEAD = 4
|
||||||
NUM_HEADS = 4
|
NUM_HEADS = 8
|
||||||
TOP_K = 1
|
TOP_K = 2
|
||||||
|
|
||||||
CHECKED_CONFIG = [ # FOR WORLD=4
|
|
||||||
(0, 1, 4, 1, 1),
|
|
||||||
(0, 1, 1, 4, 1),
|
|
||||||
(0, 1, 1, 1, 4),
|
|
||||||
(1, 4, 1, 1, 1),
|
|
||||||
(1, 1, 4, 1, 1),
|
|
||||||
(1, 1, 1, 4, 1),
|
|
||||||
(1, 1, 1, 1, 4),
|
|
||||||
(1, 2, 1, 1, 1),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize(
|
def run_mixtral_commom(config: Tuple[int, ...]):
|
||||||
"config",
|
Randomizer.reset_index()
|
||||||
[
|
|
||||||
(1, 2, 2, 1, 1),
|
|
||||||
(1, 2, 1, 2, 1),
|
|
||||||
(1, 2, 1, 1, 2),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def run_zero_with_original_model(config: Tuple[int, ...]):
|
|
||||||
stage, ep_size, pp_size, tp_size, sp_size = config
|
stage, ep_size, pp_size, tp_size, sp_size = config
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
dtype, precision = torch.float16, "fp16"
|
dtype, precision = torch.bfloat16, "bf16"
|
||||||
torch.cuda.set_device(dist.get_rank())
|
torch.cuda.set_device(dist.get_rank())
|
||||||
|
|
||||||
plugin = MoeHybridParallelPlugin(
|
plugin = MoeHybridParallelPlugin(
|
||||||
@ -165,7 +148,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
|
saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
|
||||||
check_model_equal(torch_model, saved_model)
|
check_model_equal(torch_model, saved_model, dtype=dtype)
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
if rank == world_size - 1:
|
if rank == world_size - 1:
|
||||||
@ -174,17 +157,78 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||||||
print(f"rank {dist.get_rank()} test passed")
|
print(f"rank {dist.get_rank()} test passed")
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
@parameterize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
|
||||||
|
(0, 1, 4, 1, 1),
|
||||||
|
(0, 1, 1, 4, 1),
|
||||||
|
(0, 1, 2, 2, 1),
|
||||||
|
# zero 1
|
||||||
|
(1, 4, 1, 1, 1),
|
||||||
|
(1, 1, 4, 1, 1),
|
||||||
|
(1, 1, 1, 4, 1),
|
||||||
|
(1, 2, 1, 1, 2),
|
||||||
|
# zero 2
|
||||||
|
(2, 4, 1, 1, 1),
|
||||||
|
(2, 1, 4, 1, 1),
|
||||||
|
(2, 1, 1, 4, 1),
|
||||||
|
(2, 2, 1, 1, 2),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_mixtral_test(config: Tuple[int, ...]):
|
||||||
|
run_mixtral_commom(config)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
|
||||||
|
(0, 1, 2, 4, 1),
|
||||||
|
(0, 1, 4, 2, 1),
|
||||||
|
(0, 1, 1, 4, 1),
|
||||||
|
(0, 1, 4, 1, 1),
|
||||||
|
# zero 1:
|
||||||
|
(1, 2, 1, 1, 2),
|
||||||
|
(1, 2, 1, 4, 1),
|
||||||
|
(1, 1, 1, 2, 2),
|
||||||
|
(1, 2, 2, 2, 1),
|
||||||
|
# zero 2
|
||||||
|
(2, 2, 1, 1, 2),
|
||||||
|
(2, 2, 1, 4, 1),
|
||||||
|
(2, 1, 1, 2, 2),
|
||||||
|
(2, 2, 2, 2, 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_mixtral_3d_test(config: Tuple[int, ...]):
|
||||||
|
print(f"{config=}")
|
||||||
|
run_mixtral_commom(config)
|
||||||
|
|
||||||
|
|
||||||
|
def check_mixtral(rank, world_size, port):
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
run_zero_with_original_model()
|
run_mixtral_test()
|
||||||
|
|
||||||
|
|
||||||
|
def check_mixtral_3d(rank, world_size, port):
|
||||||
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
|
run_mixtral_3d_test()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [4])
|
@pytest.mark.parametrize("world_size", [4])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_mixtral(world_size):
|
def test_mixtral(world_size):
|
||||||
spawn(run_dist, world_size)
|
spawn(check_mixtral, world_size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.largedist
|
||||||
|
@pytest.mark.parametrize("world_size", [8])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_mixtral_3d(world_size):
|
||||||
|
spawn(check_mixtral_3d, world_size)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_mixtral(world_size=4)
|
test_mixtral(world_size=8)
|
||||||
|
test_mixtral_3d(world_size=8)
|
||||||
|
Loading…
Reference in New Issue
Block a user