import dataclasses
import math
from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.load_balance import LoadBalancer

# NOTE: `MoeRouter` and `get_router_cls` are used below but were missing from the original import list;
# the module path `colossalai.moe.routers` is assumed here.
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size


class SparseMLP(nn.Module):
    """A class for users to create MoE modules in their models.

    Args:
        num_experts (int): The number of experts.
        hidden_size (int): Hidden dimension of the training model.
        intermediate_size (int): Intermediate dimension of each expert MLP.
        router_top_k (int, optional): The number of experts each token is dispatched to.
        parallel (str, optional): Parallel mode of the experts. Should be "EP", "TP" or None.
        router_loss (bool, optional): Whether the router computes its auxiliary loss.
        router_norm (bool, optional): Whether the router normalizes the routing weights.
        router_capacity_factor_train (float, optional): Capacity factor in routing during training.
        router_capacity_factor_eval (float, optional): Capacity factor in routing during evaluation.
        router_min_capacity (int, optional): The minimum capacity of each expert.
        router_noisy_policy (str, optional): The policy of the noisy function. Currently 'Jitter' and 'Gaussian'
            are supported. 'Jitter' can be found in the `Switch Transformer paper`_ and 'Gaussian' in the
            `ViT-MoE paper`_.
        router_drop_tks (bool, optional): Whether to drop tokens in evaluation.
        mlp_activation (str, optional): Activation function used in each expert MLP.
        mlp_gated (bool, optional): Whether each expert MLP is gated.
        enable_load_balance (bool, optional): Whether to enable the expert load balancer.
        load_balance_tolerance (float, optional): Tolerance passed to the load balancer.
        load_balance_beam_width (int, optional): Beam width passed to the load balancer.
        load_balance_group_swap_factor (float, optional): Group swap factor passed to the load balancer.
        enable_kernel (bool, optional): Whether to use the fused dispatch/combine kernels.
        enable_comm_overlap (bool, optional): Whether to overlap communication with expert computation.
        enable_hierarchical_comm (bool, optional): Whether to use hierarchical all-to-all for expert parallelism.
        return_gate_logits (bool, optional): Whether ``forward`` also returns the gate logits.

    A minimal usage sketch is included at the bottom of this file.

    .. _Switch Transformer paper:
        https://arxiv.org/abs/2101.03961
    .. _ViT-MoE paper:
        https://arxiv.org/abs/2106.05974
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        router_top_k: int = 1,
        parallel: str = "EP",
        router_loss: bool = True,
        router_norm: bool = False,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        router_noisy_policy: Optional[str] = None,
        router_drop_tks: bool = True,
        mlp_activation: Optional[str] = None,
        mlp_gated: bool = False,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_comm: bool = True,
        return_gate_logits: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.gated = mlp_gated
        self.return_gate_logits = return_gate_logits
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        # self.expert_parallel = MOE_MANAGER.get_parallel()
        assert parallel in ["EP", "TP", None], "parallel mode must be EP, TP or None"
        self.parallel = parallel
        self.router_loss = router_loss
        self.router_norm = router_norm

        # moe router
        noisy_func = get_noise_generator(router_noisy_policy, num_experts)
        router_cls = get_router_cls(router_top_k)
        self.topk = router_top_k
        self.router: MoeRouter = router_cls(
            capacity_factor_train=router_capacity_factor_train,
            capacity_factor_eval=router_capacity_factor_eval,
            min_capacity=router_min_capacity,
            noisy_func=noisy_func,
            drop_tks=router_drop_tks,
        )

        # gate
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))

        # moe experts
        self.experts = MLPExperts(
            num_experts=self.num_experts,
            expert_parallel=self.parallel,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            activation=mlp_activation,
            gated=mlp_gated,
            use_kernel=self.enable_kernel,
        )

        # get parallel settings
        if self.parallel is not None:
            self.ep_group = get_ep_group(self.experts)
            self.ep_size = get_ep_size(self.experts)
            self.ep_hierarchical_group = None
            if enable_hierarchical_comm:
                # TODO: move to plugin
                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
                    get_ep_group_ranks(self.experts)
                )
            self.dp_group = get_dp_group(self.experts)
        else:
            self.ep_group = None
            self.dp_group = None
        self.num_local_experts = self.experts.num_local_experts

        # load balance
        self.enable_load_balance = enable_load_balance
        if self.enable_load_balance:
            self.load_balancer = LoadBalancer(
                experts=self.experts,
                gate=self.gate_weight,
                local_expert_num=self.num_local_experts,
                expert_num=self.num_experts,
                ep_group=self.ep_group,
                dp_group=self.dp_group,
                tolerance=load_balance_tolerance,
                beam_width=load_balance_beam_width,
                group_swap_factor=load_balance_group_swap_factor,
            )

        # init param
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size).
                If ``return_gate_logits`` is True, a tuple of (output, gate_logits) is returned instead.
        """
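        # Descriptive summary of the existing code below (no behavior change):
        #   1. flatten inputs to (num_tokens, hidden_size) and compute fp32 gate logits;
        #   2. run the router to obtain combine weights / dispatch masks (and optionally update expert load);
        #   3. dispatch tokens to experts and run them locally, with EP, or with TP;
        #   4. combine the expert outputs back to the original token order and reshape to the input shape.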
        # reshape the input tokens
        tokens = inputs.reshape(-1, self.hidden_size)

        # the router operates on fp32 gate logits
        gate_logits = F.linear(tokens, self.gate_weight)
        gate_output = gate_logits.to(torch.float)

        # update expert load
        if self.enable_load_balance:
            with torch.no_grad():
                # TODO: optimize computation
                expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
                # TODO: bincount introduces synchronize, fix it
                expert_load = torch.bincount(expert_load.view(-1))
                self.load_balancer.update_load(expert_load)

        # the result from the router
        used_capacity, *route_result_list = self.router(
            inputs=gate_output,
            use_kernel=self.enable_kernel,
            ep_group=self.ep_group,
            use_loss=self.router_loss,
            use_norm=self.router_norm,
        )

        # dispatch_data: (num_experts, capacity, hidden_size)
        if self.enable_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
        else:
            # sec_mask: (num_tokens, num_experts, capacity); permute + matmul scatters tokens into expert slots
            sec_mask_f = route_result_list[1].type_as(inputs)
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # expert_output: (num_groups, num_experts, capacity, hidden_size)
        if self.parallel == "EP":
            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.parallel == "TP":
            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.parallel is None:
            expert_output = self._local_process(dispatch_data)
        else:
            raise NotImplementedError(
                "This kind of communication has not been implemented yet.\nPlease use Experts build function."
            )

        if self.enable_kernel:
            expert_output = expert_output.reshape(-1, self.hidden_size)
            ans = MoeCombine.apply(expert_output, *route_result_list)
        else:
            combine_weights = route_result_list[0].type_as(inputs)
            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
            expert_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)

        if self.return_gate_logits:
            return ans, gate_logits
        else:
            return ans

    def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
        # add a dummy group dimension so the experts see (num_groups, num_experts, capacity, hidden_size)
        expert_in = expert_in.unsqueeze(0)
        expert_out = self.experts(expert_in)
        return expert_out

    def _ep_process(
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """
        Expert Parallel

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)

        Returns:
            torch.Tensor: (num_experts, capacity, hidden_size)
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            if self.ep_hierarchical_group is not None:
                expert_input = HierarchicalAllToAll.apply(
                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = HierarchicalAllToAll.apply(
                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                return expert_output
            else:
                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
                return expert_output
        else:

            @dataclasses.dataclass
            class Capsule:
                data: torch.Tensor
                handle: Any = None

            NUM_CHUNK = 4
            NUM_STAGES = 4

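            # Descriptive note (added comment, no behavior change): the capacity dimension is split into
            # NUM_CHUNK chunks and pushed through a software pipeline so that the all-to-all of one chunk
            # overlaps with the expert computation of another. Each loop iteration below advances the
            # stages: drain a finished output all-to-all into `output`, launch the all-to-all for the
            # previous expert output, launch the all-to-all for the next input chunk, and run the experts
            # on the chunk whose input all-to-all has completed. `Capsule` pairs a tensor with its async handle.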
            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
            chunk_size = dispatch_data.shape[1] // NUM_CHUNK
            input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
            dispatch_data = dispatch_data.reshape(*input_shape)
            chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
            output = torch.empty_like(dispatch_data)

            offset = 0
            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                    offset += chunk_size
                    expert_out = None

                # all2all last output
                if _expert_out is not None:
                    expert_out = Capsule(
                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
                    )
                    _expert_out = None

                # all2all next input
                if 0 <= i < NUM_CHUNK:
                    _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))

                # compute
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
                    expert_in = None

                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output

    def _tp_process(
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """
        without overlap:
        |       C       |
        | A |       | R |

        with overlap:
        | C1 || C2 || C3 || C4 |
        | A1 || A2 |  | R1 | A3 || R2 | A4 || R3 |  | R4 |

        where C is computation, A is all gather, R is reduce scatter.

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)

        Returns:
            torch.Tensor: (num_experts, capacity, hidden_size)
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
            expert_out = self.experts(expert_in)
            expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
            return expert_out
        else:

            @dataclasses.dataclass
            class Capsule:
                data: torch.Tensor
                handle: Any
                indices: Tuple

            NUM_CHUNK = 4
            NUM_STAGES = 4

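            # Descriptive note (added comment, no behavior change): the expert dimension is split into
            # NUM_CHUNK chunks. For each chunk the input is all-gathered across the group, the experts run
            # on the gathered slice, and the result is reduce-scattered back, with the collectives of one
            # chunk overlapped with the computation of another as sketched in the docstring above.
            # `indices` records which slice of `output` each chunk belongs to.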
            assert (
                dispatch_data.shape[0] % NUM_CHUNK == 0
            ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
            chunk_size = dispatch_data.shape[0] // NUM_CHUNK
            chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
            output = torch.empty_like(dispatch_data)

            def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
                return (slice(idx * chunk_size, (idx + 1) * chunk_size),)

            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[expert_out.indices] = expert_out.data
                    expert_out = None

                # reduce scatter last output
                if _expert_out is not None:
                    expert_out = Capsule(
                        *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
                        indices=_expert_out.indices,
                    )
                    _expert_out = None

                # all gather next input
                if 0 <= i < NUM_CHUNK:
                    _expert_in = Capsule(
                        *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
                        indices=get_chunk_slice(i, chunk_size),
                    )

                # compute
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(
                        self.experts(expert_in.data, expert_in.indices),
                        handle=None,
                        indices=expert_in.indices,
                    )
                    expert_in = None

                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output


def apply_load_balance(model: nn.Module, optim: Any) -> None:
    """
    Apply load balancing to every SparseMLP layer in the model.
    """

    def _apply_recursive(module: nn.Module):
        for _, sub_module in module.named_children():
            if isinstance(sub_module, SparseMLP):
                if sub_module.enable_load_balance:
                    sub_module.load_balancer.balance_load(optim)
            _apply_recursive(sub_module)

    torch.cuda.empty_cache()
    _apply_recursive(model)
    torch.cuda.empty_cache()
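

# A minimal, non-parallel usage sketch (added for illustration; not part of the original module).
# It assumes `parallel=None` (no expert parallelism), so no distributed process group should be needed,
# and that the colossalai modules imported above are available in the environment. Shapes follow the
# docstrings: inputs are (batch_size, seq_len, hidden_size) and the output has the same shape.
if __name__ == "__main__":
    moe_layer = SparseMLP(
        num_experts=4,
        hidden_size=32,
        intermediate_size=64,
        router_top_k=2,
        parallel=None,  # no EP/TP: tokens are processed by the local experts only
        enable_kernel=False,  # use the pure-PyTorch dispatch/combine path
    )
    x = torch.randn(2, 8, 32)  # (batch_size, seq_len, hidden_size)
    y = moe_layer(x)
    print(y.shape)  # expected: torch.Size([2, 8, 32])

    # When `enable_load_balance=True` and a parallel mode is used, expert placement can be
    # rebalanced after an optimizer step, e.g.:
    #   apply_load_balance(model, optimizer)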