Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-17 07:00:37 +00:00)
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up the construction of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
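For context on the fp8 items in the commit list above (cast_to_fp8, cast_from_fp8, all_reduce_fp8 and the later scaling fix), here is a minimal, hypothetical sketch of per-tensor scaled FP8 casting. It is not the code added by this commit; the function names simply mirror the bullet text, and it assumes PyTorch >= 2.1, which provides torch.float8_e4m3fn.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3

def cast_to_fp8(x: torch.Tensor):
    # Scale x into the representable FP8 range before casting; keep the scale
    # so the receiver can undo it after communication.
    amax = x.abs().max().clamp(min=1e-12)
    scale = FP8_MAX / amax
    payload = (x.float() * scale).to(torch.float8_e4m3fn)
    return payload, scale

def cast_from_fp8(payload: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
    # Undo the scaling after the FP8 payload has been transferred.
    return (payload.to(torch.float32) / scale).to(dtype)

x = torch.randn(4, 8)
payload, scale = cast_to_fp8(x)
x_restored = cast_from_fp8(payload, scale)
# An all_reduce_fp8 built on top of this would communicate (payload, scale) pairs
# and accumulate in a higher-precision dtype before casting back down.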
@@ -1,5 +0,0 @@
from .manager import MOE_MANAGER

__all__ = [
    "MOE_MANAGER",
]
@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
         return torch.cumsum(inputs, dim=0) - 1


-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
     """
     Scale the gradient back by the number of experts
     because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
-        if ctx is not None:
-            ctx.ep_size = ep_size
+        ctx.ep_size = ep_size
         return inputs

     @staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
         return grad, None


-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
     """
     Scale the gradient by the number of experts
     because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
         return grad, None


+class DPGradScalerIn(torch.autograd.Function):
+    """
+    Scale the gradient back by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+        return grad, None, None
+
+
+class DPGradScalerOut(torch.autograd.Function):
+    """
+    Scale the gradient by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+        return grad, None, None
+
+
 def _all_to_all(
     inputs: torch.Tensor,
     input_split_sizes: Optional[List[int]] = None,
@@ -393,4 +436,7 @@ def all_to_all_uneven(
     group=None,
     overlap: bool = False,
 ):
+    assert (
+        inputs.requires_grad
+    ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
     return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
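To see what the scaler functions added above do at runtime, here is a small self-contained check (the DPGradScalerIn definition is copied from the hunk; the surrounding MoE training code is omitted, and the numbers are only illustrative).

from typing import Any, Tuple

import torch
from torch import Tensor

class DPGradScalerIn(torch.autograd.Function):
    # Identity in forward; in backward the gradient is rescaled by
    # activated_experts / moe_dp_size, as in the hunk above.
    @staticmethod
    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
        assert activated_experts != 0
        ctx.moe_dp_size = moe_dp_size
        ctx.activated_experts = activated_experts
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
        (grad,) = grad_outputs
        if ctx.moe_dp_size != ctx.activated_experts:
            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
        return grad, None, None

x = torch.ones(3, requires_grad=True)
y = DPGradScalerIn.apply(x, 4, 2)  # moe_dp_size=4, only 2 experts activated
y.sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000]) -- gradient scaled by 2/4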
@@ -1,442 +0,0 @@
from copy import deepcopy
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.distributed import ProcessGroup

from colossalai.cluster import ProcessGroupMesh
from colossalai.moe.manager import MOE_MANAGER
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.zero.low_level import LowLevelZeroOptimizer


class LoadBalancer:
    def __init__(
        self,
        experts: MLPExperts,
        gate: nn.Parameter,
        local_expert_num: int,
        expert_num: int,
        ep_group: ProcessGroup,
        dp_group: ProcessGroup,
        tolerance: Optional[float] = 0.1,
        beam_width: Optional[int] = 8,
        group_swap_factor: Optional[float] = 0.4,
    ) -> None:
        self.experts: MLPExperts = experts
        self.gate: nn.Parameter = gate
        self.moe_ep_group: ProcessGroup = ep_group
        self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
        self.moe_dp_group: ProcessGroup = dp_group
        self.tolerance = tolerance
        self.beam_width = beam_width
        self.group_swap_factor = group_swap_factor
        self.local_expert_num = local_expert_num
        self.expert_num = expert_num
        self.local_load = None
        # TODO: use a global process group mesh
        pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size
        global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size)
        self.global_dp_group = global_dp_group.get_group_along_axis(1)
        self.global_dp_rank = dist.get_rank(self.global_dp_group)
        self.global_dp_size = dist.get_world_size(self.global_dp_group)

    def _clear_load(self) -> None:
        self.local_load = None

    def _sync_load(self) -> Tensor:
        new_load = self.local_load.clone().detach()
        # all reduce load between ep group
        dist.all_reduce(new_load, group=self.moe_ep_group)
        # all reduce load between dp group
        dist.all_reduce(new_load, group=self.moe_dp_group)
        return new_load

    @staticmethod
    def _get_diff_from_avg(data: List, group: int, avg: float) -> float:
        return abs(sum(data[group]) / len(data[group]) - avg)

    @staticmethod
    def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None:
        data[group_i][index_i], data[group_j][index_j] = (
            data[group_j][index_j],
            data[group_i][index_i],
        )

    @staticmethod
    def _normalize_data(data: List) -> List:
        max_value = max(max(sublist) for sublist in data)
        data = [[i / max_value for i in sublist] for sublist in data]
        return data

    @staticmethod
    def _get_swap_loss(
        group_swap_factor: float,
        swap_list: List,
        group_i: int,
        index_i: int,
        group_j: int,
        index_j: int,
    ) -> float:
        """
        Get swap loss. The swap loss is used to avoid the situation that
        the same index is swapped twice and the same group is swapped for multiple times.
        """
        swap_loss = 0
        for swap in swap_list:
            for group_id, index_id in zip([group_i, group_j], [index_i, index_j]):
                # the group has been swapped
                if group_id in [swap[0], swap[2]]:
                    # the index has been swapped
                    # we want to avoid the situation that the same index is swapped twice
                    if index_id in [swap[1], swap[3]]:
                        swap_loss += 1e5
                    # the index has not been swapped
                    # this is acceptable but as less as possible
                    else:
                        swap_loss += group_swap_factor
        return swap_loss

    @staticmethod
    def _check_convergence(data: List, avg: float, tolerance: float):
        """
        Check whether the data is converged after swap.
        """
        for sublist in data:
            if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg:
                return False
        return True

    def _beam_search(
        self,
        inputs: Tuple[List, float, List],
        beam_width: int,
        avg: float,
        group_swap_factor: float,
    ) -> List:
        """
        Beam search for the best swap combination.
        Specifically, we swap two elements from two groups and calculate the score.
        The score is the difference between the origin group sum and the new group sum.
        The larger the score, the better the swap combination.

        Args:
            inputs (Tuple): (data, origin_score, swap_list)
            beam_width (int): beam width for beam search
            avg (float): average value of the data
            group_swap_factor (float): group loss for group swap loss

        Returns:
            List: results list
        """
        data, origin_score, swap_list = inputs
        results = []
        group_num = len(data)
        group_size = len(data[0])
        origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)]

        for group_num_i in range(group_num):
            for group_size_i in range(group_size):
                for group_num_j in range(group_num_i + 1, group_num):
                    for group_size_j in range(group_size):
                        new_data = deepcopy(data)
                        # calculate origin group sum
                        origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j]
                        # swap data
                        self._swap_data(
                            new_data,
                            group_num_i,
                            group_size_i,
                            group_num_j,
                            group_size_j,
                        )
                        # calculate new group sum
                        new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg(
                            new_data, group_num_j, avg
                        )
                        # caculate score
                        new_score = origin_diff - new_diff
                        if new_score > 0:
                            new_score = origin_score + new_score
                            # get swap loss
                            swap_loss = self._get_swap_loss(
                                group_swap_factor,
                                swap_list,
                                group_num_i,
                                group_size_i,
                                group_num_j,
                                group_size_j,
                            )
                            new_score = new_score - swap_loss
                            # update swap list
                            new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)]
                            results.append((new_data, new_score, new_swap_list))
        # sort results
        results.sort(key=lambda x: x[1], reverse=True)
        # select top k results
        results = results[:beam_width]
        return results

    def _load_to_list(self, load: Tensor) -> List:
        load_len = len(load)
        assert load_len % self.local_expert_num == 0
        load_list = []
        tmp_list = []
        for i in range(len(load)):
            tmp_list.append(float(load[i]))
            if (i + 1) % self.local_expert_num == 0:
                load_list.append(tmp_list)
                tmp_list = []
        return load_list

    def _search_balance(
        self,
        data: List,
        tolerance: Optional[float] = 0.1,
        beam_width: Optional[int] = 8,
        group_swap_factor: Optional[float] = 0.4,
        return_swapped_data: Optional[bool] = False,
    ) -> Tuple[List, List]:
        """
        Search for the best swap combination to balance the data within the specified tolerance.
        And return the balanced data and the swap list. The swap list is used to record the swap.
        The swap list is a list of tuples. Each tuple is a swap operation.

        Args:
            data (List): expert load list.
                E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
                This means there are 4 devices and each devices has 2 experts.
                The value is the load of the expert.
            tolerance (float): tolerance for balance.
            beam_width (int): beam width for beam search.
            group_swap_factor (float): group swap factor for group swap loss.
                The bigger it is, the less times a group will be swapped.
            return_swapped_data (bool): whether to return the swapped data.

        Returns:
            Tuple: (balanced data, swap list).
                The swap list is a list of tuples. Each tuple is a swap operation.
                E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means
                the first expert of the first device is swapped with the first expert
                of the second device.
        """
        norm_data = self._normalize_data(data)
        avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data)
        results = [(norm_data, 0, [])]
        stop_flag = False

        while stop_flag == False:
            new_results = []
            best_score = results[0][1]
            for i in range(len(results)):
                new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor))
            if len(new_results) == 0:
                stop_flag = True
                break
            new_results.sort(key=lambda x: x[1], reverse=True)
            new_best_score = new_results[0][1]
            if new_best_score == best_score:
                stop_flag = True
                break
            new_results = new_results[:beam_width]
            results = new_results
            for i in results:
                if self._check_convergence(results[0][0], avg, tolerance):
                    stop_flag = True
                    break

        swap_list = results[0][2]
        if return_swapped_data:
            out = deepcopy(data)
            for swap in swap_list:
                self._swap_data(out, *swap)
            return out, swap_list
        else:
            return swap_list

    @staticmethod
    def _swap_expert_single_tensor(
        weight: nn.Parameter,
        expert_idx: int,
        comm_group: ProcessGroup,
        send_first: bool,
        comm_rank: int,
    ):
        # exchange weight
        local_weight = weight.data[expert_idx]
        new_weight = torch.empty_like(local_weight)
        if send_first:
            dist.send(local_weight, dst=comm_rank, group=comm_group)
            dist.recv(new_weight, src=comm_rank, group=comm_group)
        else:
            dist.recv(new_weight, src=comm_rank, group=comm_group)
            dist.send(local_weight, dst=comm_rank, group=comm_group)
        weight.data[expert_idx] = new_weight

    def _swap_expert_param_and_optim(
        self,
        weight: nn.Parameter,
        expert_idx: int,
        comm_group: ProcessGroup,
        send_first: bool,
        comm_rank: int,
        optim: LowLevelZeroOptimizer,
    ):
        # need to update master and working param if master param exists
        # else just update working param
        if weight in optim.optim.state:
            master_weight_ptr = None
            working_weight_ptr = weight
            exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
            exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
        else:
            master_weight_ptr = optim.working_to_master_param[id(weight)]
            working_weight_ptr = weight
            exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
            exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]

        # exchange weight
        self._swap_expert_single_tensor(
            working_weight_ptr,
            expert_idx,
            comm_group,
            send_first,
            comm_rank,
        )
        if master_weight_ptr is not None:
            # TODO: exchange master weight, skip for now
            # master weight is shared by dp group
            tmp = working_weight_ptr.view(-1).split(
                working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group)
            )[dist.get_rank(self.moe_dp_group)]
            master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype))
        # exchange optim
        self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank)
        self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank)

    def _gather_global_dp_group(self, data: Tensor) -> Tensor:
        data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)]
        dist.all_gather(data_list, data, group=self.global_dp_group)
        data_list = torch.cat(data_list, dim=0)
        return data_list

    def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None:
        """
        Swap moe param and optim.
        We use different strategies to swap expert and gate.
        For expert, we exchange the param and optim of the expert by p2p.
        For gate, we all gather the gate choose the part we want.

        Args:
            swap_list (List)
            optim (LowLevelZeroOptimizer)
        """
        # get all experts weights
        local_rank = dist.get_rank(self.moe_ep_group)
        if self.experts.gated:
            weight_list = [self.experts.wi_up, self.experts.wi_gate]
        else:
            weight_list = [self.experts.wi]
        weight_list.append(self.experts.wo)

        # gate optim should be obtained first
        gate_shape = self.gate.shape
        # get master weight and optim
        master_gate_weight = optim.working_to_master_param[id(self.gate)]
        gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
        gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
        # gather
        global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape)
        global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape)
        global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape)
        assert (
            self.gate.shape
            == global_master_gate_weight.shape
            == global_gate_exp_avg.shape
            == global_gate_exp_avg_sq.shape
        )

        for swap in swap_list:
            source_group, source_idx, target_group, target_idx = swap
            source_rank = self.moe_ep_ranks[source_group]
            target_rank = self.moe_ep_ranks[target_group]
            # exchange expert
            if local_rank in [source_group, target_group]:
                for weight in weight_list:
                    if local_rank == source_group:
                        self._swap_expert_param_and_optim(
                            weight,
                            source_idx,
                            self.moe_ep_group,
                            True,
                            target_rank,
                            optim,
                        )
                    elif local_rank == target_group:
                        self._swap_expert_param_and_optim(
                            weight,
                            target_idx,
                            self.moe_ep_group,
                            False,
                            source_rank,
                            optim,
                        )
            # exchange gate
            source_expert_pos = source_group * self.local_expert_num + source_idx
            target_expert_pos = target_group * self.local_expert_num + target_idx
            for gate in [
                self.gate,
                global_master_gate_weight,
                global_gate_exp_avg,
                global_gate_exp_avg_sq,
            ]:
                origin_source = gate.data[source_expert_pos].clone().detach()
                origin_target = gate.data[target_expert_pos].clone().detach()
                gate.data[source_expert_pos], gate.data[target_expert_pos] = (
                    origin_target,
                    origin_source,
                )

        # update gate
        global_master_gate_weight = global_master_gate_weight.view(-1).split(
            global_master_gate_weight.numel() // self.global_dp_size
        )[self.global_dp_rank]
        master_gate_weight.data.copy_(global_master_gate_weight)
        global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[
            self.global_dp_rank
        ]
        gate_exp_avg.data.copy_(global_gate_exp_avg)
        global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(
            global_gate_exp_avg_sq.numel() // self.global_dp_size
        )[self.global_dp_rank]
        gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq)

    @torch.no_grad()
    def update_load(self, load: Tensor) -> None:
        if len(load) != self.expert_num:
            padding_size = self.expert_num - len(load)
            padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device)
            load = torch.cat((load, padding), dim=0)
        if self.local_load is None:
            self.local_load = load
        else:
            self.local_load += load

    @torch.no_grad()
    def balance_load(self, optim: LowLevelZeroOptimizer) -> None:
        # prepare load
        load = self._sync_load()
        load = self._load_to_list(load)
        # search balance
        swap_list = self._search_balance(load)
        if dist.get_rank() == 0:
            if len(swap_list) > 0:
                print(f"[Load Balance] Applying expert swap...")
            else:
                print(f"[Load Balance] Invalid swap, skip...")
        # swap expert and gate
        self._swap_moe_param(swap_list, optim)
        # clear load
        self._clear_load()
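As a toy, standalone illustration of the swap representation used by the removed LoadBalancer above: each tuple (group_i, index_i, group_j, index_j) swaps one expert's load between two devices. The helpers below are simplified re-implementations for the example only (no process groups, no normalization), and the swap chosen here is illustrative, not the output of the beam search.

from typing import List

def swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None:
    # Same semantics as LoadBalancer._swap_data: exchange two expert loads in place.
    data[group_i][index_i], data[group_j][index_j] = data[group_j][index_j], data[group_i][index_i]

def check_convergence(data: List, avg: float, tolerance: float) -> bool:
    # Simplified convergence test: every device's mean load is within tolerance of avg.
    return all(abs(sum(s) / len(s) - avg) <= tolerance * avg for s in data)

# 4 devices x 2 local experts, values are accumulated expert loads (docstring example)
load = [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
avg = sum(sum(s) for s in load) / sum(len(s) for s in load)

for swap in [(0, 0, 3, 1)]:  # swap expert 0 of device 0 with expert 1 of device 3
    swap_data(load, *swap)
print(load, check_convergence(load, avg, tolerance=0.3))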
@@ -1,163 +0,0 @@
from typing import Tuple

import torch
import torch.distributed as dist

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo


class MoEManager(metaclass=SingletonMeta):
    """MoE manager. This class manages different
    parallel groups in MoE context and MoE loss in training.
    """

    def __init__(self):
        self.parallel = None
        self.mode = None
        self.use_ep_inside = None
        self.world_size = None
        self._parallel_info_dict = dict()

        # router
        self.router_aux_loss = []
        self.router_z_loss = []

        # fixed mode
        self.pp_size = None
        self.dp_size = None
        self.ep_size = None

        # dynamic mode
        # Users may want to set maximum expert parallel size smaller than the world size
        # since very low bandwidth across nodes may constrain the performance of MoE
        # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
        self.max_ep_size = None

        self.has_setup = False

    @property
    def parallel_info_dict(self):
        return self._parallel_info_dict

    @property
    def is_initialized(self):
        return self.has_setup

    def setup(
        self,
        parallel: str = None,
        mode: str = "dynamic",
        max_ep_size: int = 8,
        fixed_dp_size: int = 0,
        fixed_ep_size: int = 0,
        fixed_pp_size: int = 0,
        use_ep_inside: bool = True,
    ) -> None:
        """
        Setup MoE distributed context.

        Args:
            seed (int): Random seed. Defaults to 42.
            use_kernel_optim (bool, optional): Use cuda kernel. Defaults to True.
            parallel (bool, optional): Parallel mode, should be EP, TP or None. Defaults to None.
            mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
                In fixed mode, the ep size and dp size is fixed.
                In dynamic mode, the ep size and dp size will be changed according to num experts.
            max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
            fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
            fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
            fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
            use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
        """
        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
        assert torch.cuda.is_available(), "MoE requires to enable CUDA first"

        self.parallel = parallel
        self.use_ep_inside = use_ep_inside
        self.world_size = dist.get_world_size()

        # init by mode
        self.mode = mode
        assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
        if self.mode == "dynamic":
            self.max_ep_size = min(max_ep_size, self.world_size)
        else:
            assert (
                fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
            ), "dp_size, ep_size and pp_size should be greater than 0"
            assert (
                isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
            ), "dp_size, ep_size and pp_size should be int"
            self.ep_size = fixed_ep_size
            self.dp_size = fixed_dp_size
            self.pp_size = fixed_pp_size

        self.has_setup = True

    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
        """Calculate the Data Parallel Group and Expert Parallel Group.

        Parameters
        ----------
        num_experts : int
            The number experts

        Returns
        -------
        int, MoeParallelInfo
            number of local experts, the MoeParallelInfo of the current ep_size
        """

        if self.mode == "dynamic":
            gt_flag = num_experts % self.max_ep_size == 0  # check whether num_experts is greater
            lt_flag = self.max_ep_size % num_experts == 0  # check whether num_experts is less
            assert gt_flag or lt_flag, (
                "Automatic experts placement dose not not support expert number"
                " is not a multiple of ep size or vice versa."
            )
            dp_size = 1 if gt_flag else self.world_size // num_experts
            ep_size = min(self.world_size // dp_size, self.max_ep_size)
            dp_size = self.world_size // ep_size
            pp_size = 1
        else:
            dp_size = self.dp_size
            ep_size = self.ep_size
            pp_size = self.pp_size

        # Calculate the number of experts for each GPU
        if use_tp:
            num_local_experts = num_experts
        else:
            if self.mode == "dynamic":
                num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
            else:
                num_local_experts = num_experts // ep_size

        if not (ep_size in self.parallel_info_dict):
            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
            if dist.get_rank() == 0:
                if self.use_ep_inside:
                    print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
                else:
                    print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")

        return num_local_experts, self.parallel_info_dict[ep_size]

    def reset_loss(self):
        self.router_aux_loss, self.router_z_loss = [], []

    def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
        self.router_aux_loss.append(aux_loss)
        self.router_z_loss.append(z_loss)

    def get_loss(self):
        cur_loss = self.router_aux_loss, self.router_z_loss
        return cur_loss

    def get_parallel(self):
        return self.parallel


MOE_MANAGER = MoEManager()
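The dynamic-mode sizing in MoEManager.get_info above reduces to simple arithmetic. Here is a standalone sketch of that calculation (pure Python, no torch.distributed required; the function name is made up for the example):

def dynamic_moe_sizes(world_size: int, num_experts: int, max_ep_size: int = 8):
    # Mirrors the dynamic branch of get_info above.
    gt_flag = num_experts % max_ep_size == 0  # num_experts is a multiple of the ep cap
    lt_flag = max_ep_size % num_experts == 0  # num_experts divides the ep cap
    assert gt_flag or lt_flag, "num_experts must divide, or be a multiple of, max_ep_size"
    dp_size = 1 if gt_flag else world_size // num_experts
    ep_size = min(world_size // dp_size, max_ep_size)
    dp_size = world_size // ep_size
    num_local_experts = 1 if lt_flag else num_experts // max_ep_size
    return ep_size, dp_size, num_local_experts

# 16 GPUs, 8 experts, max_ep_size=8 -> ep=8, dp=2, one expert per EP rank
print(dynamic_moe_sizes(16, 8))  # (8, 2, 1)
# 16 GPUs, 4 experts -> ep=4, dp=4, one local expert
print(dynamic_moe_sizes(16, 4))  # (4, 4, 1)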
@@ -1,218 +0,0 @@
import contextlib
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.distributed_c10d import get_process_group_ranks

from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor


class ForceFP32Parameter(torch.nn.Parameter):
    def half(self, memory_format=None):
        return self.data.clone()


class NormalNoiseGenerator:
    """Generates a random noisy mask for logits tensor.

    All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
    `E = the number of experts`.

    Args:
        num_experts (int): The number of experts.
    """

    def __init__(self, num_experts: int):
        self.normal = torch.distributions.normal.Normal(
            loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
            scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
        ).rsample

    def __call__(self, inputs: torch.Tensor):
        noisy = self.normal(inputs.shape)
        return inputs + noisy


class UniformNoiseGenerator:
    """Generates a random noisy mask for logits tensor.
    copied from mesh tensorflow:
    Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.

    Args:
        eps (float, optional): Epsilon in generator, defaults 1e-2.
    """

    def __init__(self, eps: float = 1e-2):
        self.uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
            high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
        ).rsample

    def __call__(self, inputs: torch.Tensor):
        noisy = self.uniform(inputs.shape)
        return inputs * noisy


def autocast_softmax(logit: torch.Tensor, dim: int):
    return F.softmax(logit, dim=dim, detype=torch.float32)


def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
    if noise_type is None:
        return None
    elif noise_type == "Jitter":
        noisy_func = UniformNoiseGenerator()
    elif noise_type == "Gaussian":
        noisy_func = NormalNoiseGenerator(num_experts)
    else:
        raise NotImplementedError("Unsupported input noisy policy")
    return noisy_func


def get_activation(act: str) -> Callable:
    if act is None or act == "relu":
        return torch.nn.ReLU()
    elif act == "gelu":
        return torch.nn.GELU()
    elif act == "swiglu":
        return SwiGLU
    elif act == "silu":
        return torch.nn.SiLU()
    else:
        raise NotImplementedError("Unsupported activation function")


def SwiGLU(x):
    """Gated linear unit activation function.
    Args:
        x : input array
        axis: the axis along which the split should be computed (default: -1)
    """
    size = x.shape[-1]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, -1)
    return x1 * (x2 * torch.sigmoid(x2))


@contextlib.contextmanager
def skip_init():
    """
    skip param random init
    """

    def _skip_init(*args, **kwargs):
        pass

    init_func = {
        "constant_": torch.nn.init.constant_,
        "uniform_": torch.nn.init.uniform_,
        "normal_": torch.nn.init.normal_,
        "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
        "kaiming_normal_": torch.nn.init.kaiming_normal_,
        "xavier_normal_": torch.nn.init.xavier_normal_,
        "xavier_uniform_": torch.nn.init.xavier_uniform_,
        "trunc_normal_": torch.nn.init.trunc_normal_,
    }

    for method_name, original_init in init_func.items():
        setattr(torch.nn.init, method_name, _skip_init)

    yield

    for method_name, original_init in init_func.items():
        setattr(torch.nn.init, method_name, original_init)

    return


def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
    """Returns a parameter dictionary, the key of which is the expert parallel
    size of every parameter. Since the parameters in data parallelism is replicated
    in each GPU, we set their ep_size to 1.

    Args:
        model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
    """
    epsize_param_dict = dict()
    for param in model.parameters():
        if not is_moe_tensor(param):
            ep_size = 1  # set ep_size to 1 for dp parameters
        else:
            ep_size = dist.get_world_size(param.ep_group)
        if ep_size not in epsize_param_dict:
            epsize_param_dict[ep_size] = []
        epsize_param_dict[ep_size].append(param)

    return epsize_param_dict


def sync_moe_model_param(model: nn.Module):
    """Make sure model parameters are consistent in MoE parallel context.

    Args:
        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
    """
    param_dict = get_moe_epsize_param_dict(model)

    # synchronize the parameters whose dp_group is the whole world
    if 1 in param_dict:
        for param in param_dict[1]:
            dist.broadcast(param, src=0)

    for ep_size in param_dict:
        # When ep_size = world_size, communication is not needed
        if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
            for param in param_dict[ep_size]:
                src_rank = get_process_group_ranks(param.dp_group)[0]
                dist.broadcast(param, src=src_rank, group=param.dp_group)


def set_moe_args(config: Any, args: dict):
    for k, v in args.items():
        setattr(config, k, v)


def create_ep_hierarchical_group(
    ep_group_ranks: List[int],
    nproc_per_node: Optional[int] = None,
) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
    """
    e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
    Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
    """
    assert dist.is_initialized(), "Please initialize torch.distributed first."
    rank = dist.get_rank()
    if nproc_per_node is None:
        nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
        assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
        nproc_per_node = int(nproc_per_node)
    else:
        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
    num_node = dist.get_world_size() // nproc_per_node

    intra_src_rank = None
    ep_intra_node_group = None
    for i in range(num_node):
        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
        group = dist.new_group(ep_intra_ranks)
        if rank in ep_intra_ranks:
            assert ep_intra_node_group is None
            ep_intra_node_group = group
            intra_src_rank = ep_intra_ranks[0]

    ep_inter_node_group = None
    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
    if len(ep_inter_ranks) > 1:
        group = dist.new_group(ep_inter_ranks)
        if rank in ep_inter_ranks:
            ep_inter_node_group = group

    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
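As a quick usage note for the SwiGLU helper in the deleted utils above: the last dimension is split in half, and one half gates the other with SiLU. The snippet below is a self-contained re-definition for illustration only (lowercase name to distinguish it from the original).

import torch

def swiglu(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half; x2 gates x1 via SiLU (x2 * sigmoid(x2)).
    size = x.shape[-1]
    assert size % 2 == 0, "last dimension must be even"
    x1, x2 = torch.split(x, size // 2, dim=-1)
    return x1 * (x2 * torch.sigmoid(x2))

x = torch.randn(2, 8)  # e.g. the output of a fused up/gate projection
y = swiglu(x)
print(y.shape)  # torch.Size([2, 4]) -- the feature dimension is halved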