[FP8] rebase main (#5963)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
flybird11111
2024-08-06 16:29:37 +08:00
committed by GitHub
parent 53cb9606bd
commit 0c10afd372
208 changed files with 10962 additions and 2892 deletions


@@ -1,5 +0,0 @@
from .manager import MOE_MANAGER
__all__ = [
"MOE_MANAGER",
]


@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
return torch.cumsum(inputs, dim=0) - 1
-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
"""
Scale the gradient back by the number of experts
because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
-if ctx is not None:
-ctx.ep_size = ep_size
+ctx.ep_size = ep_size
return inputs
@staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
return grad, None
-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
"""
Scale the gradient by the number of experts
because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
return grad, None
+class DPGradScalerIn(torch.autograd.Function):
+"""
+Scale the gradient back by the number of experts
+because the batch size increases in the moe stage
+"""
+@staticmethod
+def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+ctx.moe_dp_size = moe_dp_size
+ctx.activated_experts = activated_experts
+return inputs
+@staticmethod
+def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+assert len(grad_outputs) == 1
+grad = grad_outputs[0]
+if ctx.moe_dp_size != ctx.activated_experts:
+grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+return grad, None, None
+class DPGradScalerOut(torch.autograd.Function):
+"""
+Scale the gradient by the number of experts
+because the batch size increases in the moe stage
+"""
+@staticmethod
+def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+ctx.moe_dp_size = moe_dp_size
+ctx.activated_experts = activated_experts
+return inputs
+@staticmethod
+def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+assert len(grad_outputs) == 1
+grad = grad_outputs[0]
+if ctx.moe_dp_size != ctx.activated_experts:
+grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+return grad, None, None
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
@@ -393,4 +436,7 @@ def all_to_all_uneven(
group=None,
overlap: bool = False,
):
+assert (
+inputs.requires_grad
+), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
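
Note on the DP gradient scalers in this file: DPGradScalerIn and DPGradScalerOut return their input unchanged in forward and only multiply the gradient in backward, by activated_experts / moe_dp_size and by its inverse respectively. The sketch below reproduces just that arithmetic in plain Python, without torch. The function names dp_scale_in and dp_scale_out are hypothetical and are not part of ColossalAI.

```python
def dp_scale_in(moe_dp_size: int, activated_experts: int) -> float:
    """Factor the 'in' scaler applies to the gradient in backward (hypothetical helper)."""
    assert activated_experts != 0, "shouldn't be called when no expert is activated"
    if moe_dp_size == activated_experts:
        return 1.0  # the diff skips the multiplication in this case
    return activated_experts / moe_dp_size


def dp_scale_out(moe_dp_size: int, activated_experts: int) -> float:
    """Factor the 'out' scaler applies to the gradient in backward (hypothetical helper)."""
    assert activated_experts != 0, "shouldn't be called when no expert is activated"
    if moe_dp_size == activated_experts:
        return 1.0
    return moe_dp_size / activated_experts


if __name__ == "__main__":
    # With 4 MoE data-parallel ranks and 2 of them activating an expert,
    # the two factors are reciprocal.
    print(dp_scale_in(4, 2), dp_scale_out(4, 2))
```

Because the two factors are reciprocal, wrapping a block with both scalers leaves the gradient magnitude outside the block unchanged and rescales it only inside.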


@@ -1,442 +0,0 @@
from copy import deepcopy
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
from colossalai.moe.manager import MOE_MANAGER
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.zero.low_level import LowLevelZeroOptimizer
class LoadBalancer:
def __init__(
self,
experts: MLPExperts,
gate: nn.Parameter,
local_expert_num: int,
expert_num: int,
ep_group: ProcessGroup,
dp_group: ProcessGroup,
tolerance: Optional[float] = 0.1,
beam_width: Optional[int] = 8,
group_swap_factor: Optional[float] = 0.4,
) -> None:
self.experts: MLPExperts = experts
self.gate: nn.Parameter = gate
self.moe_ep_group: ProcessGroup = ep_group
self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
self.moe_dp_group: ProcessGroup = dp_group
self.tolerance = tolerance
self.beam_width = beam_width
self.group_swap_factor = group_swap_factor
self.local_expert_num = local_expert_num
self.expert_num = expert_num
self.local_load = None
# TODO: use a global process group mesh
pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size
global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size)
self.global_dp_group = global_dp_group.get_group_along_axis(1)
self.global_dp_rank = dist.get_rank(self.global_dp_group)
self.global_dp_size = dist.get_world_size(self.global_dp_group)
def _clear_load(self) -> None:
self.local_load = None
def _sync_load(self) -> Tensor:
new_load = self.local_load.clone().detach()
# all reduce load between ep group
dist.all_reduce(new_load, group=self.moe_ep_group)
# all reduce load between dp group
dist.all_reduce(new_load, group=self.moe_dp_group)
return new_load
@staticmethod
def _get_diff_from_avg(data: List, group: int, avg: float) -> float:
return abs(sum(data[group]) / len(data[group]) - avg)
@staticmethod
def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None:
data[group_i][index_i], data[group_j][index_j] = (
data[group_j][index_j],
data[group_i][index_i],
)
@staticmethod
def _normalize_data(data: List) -> List:
max_value = max(max(sublist) for sublist in data)
data = [[i / max_value for i in sublist] for sublist in data]
return data
@staticmethod
def _get_swap_loss(
group_swap_factor: float,
swap_list: List,
group_i: int,
index_i: int,
group_j: int,
index_j: int,
) -> float:
"""
Get swap loss. The swap loss is used to avoid the situation that
the same index is swapped twice and the same group is swapped for multiple times.
"""
swap_loss = 0
for swap in swap_list:
for group_id, index_id in zip([group_i, group_j], [index_i, index_j]):
# the group has been swapped
if group_id in [swap[0], swap[2]]:
# the index has been swapped
# we want to avoid the situation that the same index is swapped twice
if index_id in [swap[1], swap[3]]:
swap_loss += 1e5
# the index has not been swapped
# this is acceptable but as less as possible
else:
swap_loss += group_swap_factor
return swap_loss
@staticmethod
def _check_convergence(data: List, avg: float, tolerance: float):
"""
Check whether the data is converged after swap.
"""
for sublist in data:
if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg:
return False
return True
def _beam_search(
self,
inputs: Tuple[List, float, List],
beam_width: int,
avg: float,
group_swap_factor: float,
) -> List:
"""
Beam search for the best swap combination.
Specifically, we swap two elements from two groups and calculate the score.
The score is the difference between the origin group sum and the new group sum.
The larger the score, the better the swap combination.
Args:
inputs (Tuple): (data, origin_score, swap_list)
beam_width (int): beam width for beam search
avg (float): average value of the data
group_swap_factor (float): group loss for group swap loss
Returns:
List: results list
"""
data, origin_score, swap_list = inputs
results = []
group_num = len(data)
group_size = len(data[0])
origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)]
for group_num_i in range(group_num):
for group_size_i in range(group_size):
for group_num_j in range(group_num_i + 1, group_num):
for group_size_j in range(group_size):
new_data = deepcopy(data)
# calculate origin group sum
origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j]
# swap data
self._swap_data(
new_data,
group_num_i,
group_size_i,
group_num_j,
group_size_j,
)
# calculate new group sum
new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg(
new_data, group_num_j, avg
)
# calculate score
new_score = origin_diff - new_diff
if new_score > 0:
new_score = origin_score + new_score
# get swap loss
swap_loss = self._get_swap_loss(
group_swap_factor,
swap_list,
group_num_i,
group_size_i,
group_num_j,
group_size_j,
)
new_score = new_score - swap_loss
# update swap list
new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)]
results.append((new_data, new_score, new_swap_list))
# sort results
results.sort(key=lambda x: x[1], reverse=True)
# select top k results
results = results[:beam_width]
return results
def _load_to_list(self, load: Tensor) -> List:
load_len = len(load)
assert load_len % self.local_expert_num == 0
load_list = []
tmp_list = []
for i in range(len(load)):
tmp_list.append(float(load[i]))
if (i + 1) % self.local_expert_num == 0:
load_list.append(tmp_list)
tmp_list = []
return load_list
def _search_balance(
self,
data: List,
tolerance: Optional[float] = 0.1,
beam_width: Optional[int] = 8,
group_swap_factor: Optional[float] = 0.4,
return_swapped_data: Optional[bool] = False,
) -> Tuple[List, List]:
"""
Search for the best swap combination to balance the data within the specified tolerance.
And return the balanced data and the swap list. The swap list is used to record the swap.
The swap list is a list of tuples. Each tuple is a swap operation.
Args:
data (List): expert load list.
E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
This means there are 4 devices and each devices has 2 experts.
The value is the load of the expert.
tolerance (float): tolerance for balance.
beam_width (int): beam width for beam search.
group_swap_factor (float): group swap factor for group swap loss.
The bigger it is, the less times a group will be swapped.
return_swapped_data (bool): whether to return the swapped data.
Returns:
Tuple: (balanced data, swap list).
The swap list is a list of tuples. Each tuple is a swap operation.
E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means
the first expert of the first device is swapped with the first expert
of the second device.
"""
norm_data = self._normalize_data(data)
avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data)
results = [(norm_data, 0, [])]
stop_flag = False
while stop_flag == False:
new_results = []
best_score = results[0][1]
for i in range(len(results)):
new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor))
if len(new_results) == 0:
stop_flag = True
break
new_results.sort(key=lambda x: x[1], reverse=True)
new_best_score = new_results[0][1]
if new_best_score == best_score:
stop_flag = True
break
new_results = new_results[:beam_width]
results = new_results
for i in results:
if self._check_convergence(results[0][0], avg, tolerance):
stop_flag = True
break
swap_list = results[0][2]
if return_swapped_data:
out = deepcopy(data)
for swap in swap_list:
self._swap_data(out, *swap)
return out, swap_list
else:
return swap_list
@staticmethod
def _swap_expert_single_tensor(
weight: nn.Parameter,
expert_idx: int,
comm_group: ProcessGroup,
send_first: bool,
comm_rank: int,
):
# exchange weight
local_weight = weight.data[expert_idx]
new_weight = torch.empty_like(local_weight)
if send_first:
dist.send(local_weight, dst=comm_rank, group=comm_group)
dist.recv(new_weight, src=comm_rank, group=comm_group)
else:
dist.recv(new_weight, src=comm_rank, group=comm_group)
dist.send(local_weight, dst=comm_rank, group=comm_group)
weight.data[expert_idx] = new_weight
def _swap_expert_param_and_optim(
self,
weight: nn.Parameter,
expert_idx: int,
comm_group: ProcessGroup,
send_first: bool,
comm_rank: int,
optim: LowLevelZeroOptimizer,
):
# need to update master and working param if master param exists
# else just update working param
if weight in optim.optim.state:
master_weight_ptr = None
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
else:
master_weight_ptr = optim.working_to_master_param[id(weight)]
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
# exchange weight
self._swap_expert_single_tensor(
working_weight_ptr,
expert_idx,
comm_group,
send_first,
comm_rank,
)
if master_weight_ptr is not None:
# TODO: exchange master weight, skip for now
# master weight is shared by dp group
tmp = working_weight_ptr.view(-1).split(
working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group)
)[dist.get_rank(self.moe_dp_group)]
master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype))
# exchange optim
self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank)
self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank)
def _gather_global_dp_group(self, data: Tensor) -> Tensor:
data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)]
dist.all_gather(data_list, data, group=self.global_dp_group)
data_list = torch.cat(data_list, dim=0)
return data_list
def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None:
"""
Swap moe param and optim.
We use different strategies to swap expert and gate.
For expert, we exchange the param and optim of the expert by p2p.
For gate, we all gather the gate choose the part we want.
Args:
swap_list (List)
optim (LowLevelZeroOptimizer)
"""
# get all experts weights
local_rank = dist.get_rank(self.moe_ep_group)
if self.experts.gated:
weight_list = [self.experts.wi_up, self.experts.wi_gate]
else:
weight_list = [self.experts.wi]
weight_list.append(self.experts.wo)
# gate optim should be obtained first
gate_shape = self.gate.shape
# get master weight and optim
master_gate_weight = optim.working_to_master_param[id(self.gate)]
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
# gather
global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape)
global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape)
global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape)
assert (
self.gate.shape
== global_master_gate_weight.shape
== global_gate_exp_avg.shape
== global_gate_exp_avg_sq.shape
)
for swap in swap_list:
source_group, source_idx, target_group, target_idx = swap
source_rank = self.moe_ep_ranks[source_group]
target_rank = self.moe_ep_ranks[target_group]
# exchange expert
if local_rank in [source_group, target_group]:
for weight in weight_list:
if local_rank == source_group:
self._swap_expert_param_and_optim(
weight,
source_idx,
self.moe_ep_group,
True,
target_rank,
optim,
)
elif local_rank == target_group:
self._swap_expert_param_and_optim(
weight,
target_idx,
self.moe_ep_group,
False,
source_rank,
optim,
)
# exchange gate
source_expert_pos = source_group * self.local_expert_num + source_idx
target_expert_pos = target_group * self.local_expert_num + target_idx
for gate in [
self.gate,
global_master_gate_weight,
global_gate_exp_avg,
global_gate_exp_avg_sq,
]:
origin_source = gate.data[source_expert_pos].clone().detach()
origin_target = gate.data[target_expert_pos].clone().detach()
gate.data[source_expert_pos], gate.data[target_expert_pos] = (
origin_target,
origin_source,
)
# update gate
global_master_gate_weight = global_master_gate_weight.view(-1).split(
global_master_gate_weight.numel() // self.global_dp_size
)[self.global_dp_rank]
master_gate_weight.data.copy_(global_master_gate_weight)
global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[
self.global_dp_rank
]
gate_exp_avg.data.copy_(global_gate_exp_avg)
global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(
global_gate_exp_avg_sq.numel() // self.global_dp_size
)[self.global_dp_rank]
gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq)
@torch.no_grad()
def update_load(self, load: Tensor) -> None:
if len(load) != self.expert_num:
padding_size = self.expert_num - len(load)
padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device)
load = torch.cat((load, padding), dim=0)
if self.local_load is None:
self.local_load = load
else:
self.local_load += load
@torch.no_grad()
def balance_load(self, optim: LowLevelZeroOptimizer) -> None:
# prepare load
load = self._sync_load()
load = self._load_to_list(load)
# search balance
swap_list = self._search_balance(load)
if dist.get_rank() == 0:
if len(swap_list) > 0:
print(f"[Load Balance] Applying expert swap...")
else:
print(f"[Load Balance] Invalid swap, skip...")
# swap expert and gate
self._swap_moe_param(swap_list, optim)
# clear load
self._clear_load()
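
Note on the removed LoadBalancer: its search repeatedly tries swapping one expert between two devices and keeps the swaps that bring each device's mean load closer to the global average. The sketch below is a simplified, greedy single-step version of that idea. It leaves out the normalization, the swap-loss penalty and the beam search of the removed code, and the helper names (total_imbalance, best_single_swap, apply_swap) are hypothetical.

```python
from copy import deepcopy
from typing import List, Optional, Tuple

Swap = Tuple[int, int, int, int]  # (group_i, index_i, group_j, index_j)


def _avg(data: List[List[float]]) -> float:
    # Mean of the per-device mean loads.
    return sum(sum(g) / len(g) for g in data) / len(data)


def _diff_from_avg(data: List[List[float]], group: int, avg: float) -> float:
    return abs(sum(data[group]) / len(data[group]) - avg)


def total_imbalance(data: List[List[float]]) -> float:
    avg = _avg(data)
    return sum(_diff_from_avg(data, i, avg) for i in range(len(data)))


def apply_swap(data: List[List[float]], swap: Swap) -> List[List[float]]:
    gi, ii, gj, ij = swap
    out = deepcopy(data)
    out[gi][ii], out[gj][ij] = out[gj][ij], out[gi][ii]
    return out


def best_single_swap(data: List[List[float]]) -> Optional[Swap]:
    """Return the swap with the largest reduction in deviation, or None if no swap helps."""
    avg = _avg(data)
    best, best_gain = None, 0.0
    for gi in range(len(data)):
        for ii in range(len(data[gi])):
            for gj in range(gi + 1, len(data)):
                for ij in range(len(data[gj])):
                    before = _diff_from_avg(data, gi, avg) + _diff_from_avg(data, gj, avg)
                    trial = apply_swap(data, (gi, ii, gj, ij))
                    after = _diff_from_avg(trial, gi, avg) + _diff_from_avg(trial, gj, avg)
                    if before - after > best_gain:
                        best, best_gain = (gi, ii, gj, ij), before - after
    return best


if __name__ == "__main__":
    # Example load list from the removed docstring: 4 devices, 2 experts each.
    load = [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
    swap = best_single_swap(load)
    print(swap, total_imbalance(load), total_imbalance(apply_swap(load, swap)))
```

The removed implementation additionally penalizes swapping the same expert twice and keeps the top beam_width candidates per round, which this sketch does not attempt.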


@@ -1,163 +0,0 @@
from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
class MoEManager(metaclass=SingletonMeta):
"""MoE manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
def __init__(self):
self.parallel = None
self.mode = None
self.use_ep_inside = None
self.world_size = None
self._parallel_info_dict = dict()
# router
self.router_aux_loss = []
self.router_z_loss = []
# fixed mode
self.pp_size = None
self.dp_size = None
self.ep_size = None
# dynamic mode
# Users may want to set maximum expert parallel size smaller than the world size
# since very low bandwidth across nodes may constrain the performance of MoE
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
self.max_ep_size = None
self.has_setup = False
@property
def parallel_info_dict(self):
return self._parallel_info_dict
@property
def is_initialized(self):
return self.has_setup
def setup(
self,
parallel: str = None,
mode: str = "dynamic",
max_ep_size: int = 8,
fixed_dp_size: int = 0,
fixed_ep_size: int = 0,
fixed_pp_size: int = 0,
use_ep_inside: bool = True,
) -> None:
"""
Setup MoE distributed context.
Args:
seed (int): Random seed. Defaults to 42.
use_kernel_optim (bool, optional): Use cuda kernel. Defaults to True.
parallel (bool, optional): Parallel mode, should be EP, TP or None. Defaults to None.
mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
In fixed mode, the ep size and dp size is fixed.
In dynamic mode, the ep size and dp size will be changed according to num experts.
max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
"""
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
self.parallel = parallel
self.use_ep_inside = use_ep_inside
self.world_size = dist.get_world_size()
# init by mode
self.mode = mode
assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
if self.mode == "dynamic":
self.max_ep_size = min(max_ep_size, self.world_size)
else:
assert (
fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
), "dp_size, ep_size and pp_size should be greater than 0"
assert (
isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
), "dp_size, ep_size and pp_size should be int"
self.ep_size = fixed_ep_size
self.dp_size = fixed_dp_size
self.pp_size = fixed_pp_size
self.has_setup = True
    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
        """Calculate the Data Parallel Group and Expert Parallel Group.

        Parameters
        ----------
        num_experts : int
            The number of experts
        use_tp : bool
            If True, the number of local experts equals num_experts

        Returns
        -------
        int, MoeParallelInfo
            number of local experts, the MoeParallelInfo of the current ep_size
        """

        if self.mode == "dynamic":
            gt_flag = num_experts % self.max_ep_size == 0  # num_experts is a multiple of max_ep_size
            lt_flag = self.max_ep_size % num_experts == 0  # max_ep_size is a multiple of num_experts
            assert gt_flag or lt_flag, (
                "Automatic expert placement does not support an expert number"
                " that is not a multiple of ep size or vice versa."
            )
            dp_size = 1 if gt_flag else self.world_size // num_experts
            ep_size = min(self.world_size // dp_size, self.max_ep_size)
            dp_size = self.world_size // ep_size
            pp_size = 1
        else:
            dp_size = self.dp_size
            ep_size = self.ep_size
            pp_size = self.pp_size

        # Calculate the number of experts for each GPU
        if use_tp:
            num_local_experts = num_experts
        else:
            if self.mode == "dynamic":
                num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
            else:
                num_local_experts = num_experts // ep_size

        if ep_size not in self.parallel_info_dict:
            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
            if dist.get_rank() == 0:
                if self.use_ep_inside:
                    print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
                else:
                    print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")

        return num_local_experts, self.parallel_info_dict[ep_size]
    def reset_loss(self):
        self.router_aux_loss, self.router_z_loss = [], []

    def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
        self.router_aux_loss.append(aux_loss)
        self.router_z_loss.append(z_loss)

    def get_loss(self):
        cur_loss = self.router_aux_loss, self.router_z_loss
        return cur_loss

    def get_parallel(self):
        return self.parallel


MOE_MANAGER = MoEManager()
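To make the dynamic-mode arithmetic in `get_info` concrete, here is a minimal standalone sketch that mirrors the size computation without `torch.distributed`. The function name `dynamic_moe_sizes` and the explicit `world_size` argument are assumptions made for illustration; they are not part of the ColossalAI API.

```python
from typing import Tuple


def dynamic_moe_sizes(
    num_experts: int, world_size: int, max_ep_size: int, use_tp: bool = False
) -> Tuple[int, int, int]:
    """Hypothetical helper: returns (num_local_experts, dp_size, ep_size).

    Mirrors the "dynamic" branch of MoEManager.get_info shown above.
    """
    # setup() clamps max_ep_size to the world size
    max_ep_size = min(max_ep_size, world_size)

    gt_flag = num_experts % max_ep_size == 0  # experts are a multiple of max_ep_size
    lt_flag = max_ep_size % num_experts == 0  # max_ep_size is a multiple of experts
    if not (gt_flag or lt_flag):
        raise ValueError("num_experts must be a multiple of ep size or vice versa")

    dp_size = 1 if gt_flag else world_size // num_experts
    ep_size = min(world_size // dp_size, max_ep_size)
    dp_size = world_size // ep_size

    if use_tp:
        num_local_experts = num_experts
    else:
        num_local_experts = 1 if lt_flag else num_experts // max_ep_size
    return num_local_experts, dp_size, ep_size


if __name__ == "__main__":
    # 16 experts on 8 ranks: 2 experts per rank, no expert replication
    print(dynamic_moe_sizes(16, 8, 8))  # (2, 1, 8)
    # 4 experts on 8 ranks: 1 expert per rank, each expert replicated twice
    print(dynamic_moe_sizes(4, 8, 8))  # (1, 2, 4)
```

When there are fewer experts than ranks, the sketch shows the leftover ranks being absorbed into a larger data parallel size rather than a larger expert parallel size.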

@@ -1,218 +0,0 @@
import contextlib
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.distributed_c10d import get_process_group_ranks
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
class ForceFP32Parameter(torch.nn.Parameter):
    def half(self, memory_format=None):
        return self.data.clone()
class NormalNoiseGenerator:
    """Generates a random noisy mask for logits tensor.

    All noise is sampled from a normal distribution with mean 0 and standard
    deviation :math:`1 / E^2`, where `E = the number of experts`.

    Args:
        num_experts (int): The number of experts.
    """

    def __init__(self, num_experts: int):
        self.normal = torch.distributions.normal.Normal(
            loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
            scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
        ).rsample

    def __call__(self, inputs: torch.Tensor):
        noisy = self.normal(inputs.shape)
        return inputs + noisy
class UniformNoiseGenerator:
    """Generates a random noisy mask for logits tensor.
    Copied from Mesh TensorFlow:
    Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.

    Args:
        eps (float, optional): Epsilon in generator, defaults to 1e-2.
    """

    def __init__(self, eps: float = 1e-2):
        self.uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
            high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
        ).rsample

    def __call__(self, inputs: torch.Tensor):
        noisy = self.uniform(inputs.shape)
        return inputs * noisy
def autocast_softmax(logit: torch.Tensor, dim: int):
    # compute the softmax in float32 regardless of the input dtype
    return F.softmax(logit, dim=dim, dtype=torch.float32)
def get_noise_generator(noise_type: Optional[str], num_experts: int) -> Optional[Callable]:
    if noise_type is None:
        return None
    elif noise_type == "Jitter":
        noisy_func = UniformNoiseGenerator()
    elif noise_type == "Gaussian":
        noisy_func = NormalNoiseGenerator(num_experts)
    else:
        raise NotImplementedError("Unsupported input noisy policy")
    return noisy_func
def get_activation(act: str) -> Callable:
    if act is None or act == "relu":
        return torch.nn.ReLU()
    elif act == "gelu":
        return torch.nn.GELU()
    elif act == "swiglu":
        return SwiGLU
    elif act == "silu":
        return torch.nn.SiLU()
    else:
        raise NotImplementedError("Unsupported activation function")
def SwiGLU(x):
    """Gated linear unit activation function.

    Splits the last dimension of x into two halves x1 and x2 and
    returns x1 * (x2 * sigmoid(x2)).

    Args:
        x : input tensor; the size of its last dimension must be even
    """
    size = x.shape[-1]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, -1)
    return x1 * (x2 * torch.sigmoid(x2))
@contextlib.contextmanager
def skip_init():
    """
    skip param random init
    """

    def _skip_init(*args, **kwargs):
        pass

    init_func = {
        "constant_": torch.nn.init.constant_,
        "uniform_": torch.nn.init.uniform_,
        "normal_": torch.nn.init.normal_,
        "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
        "kaiming_normal_": torch.nn.init.kaiming_normal_,
        "xavier_normal_": torch.nn.init.xavier_normal_,
        "xavier_uniform_": torch.nn.init.xavier_uniform_,
        "trunc_normal_": torch.nn.init.trunc_normal_,
    }

    # replace every initializer with a no-op
    for method_name in init_func:
        setattr(torch.nn.init, method_name, _skip_init)

    try:
        yield
    finally:
        # restore the original initializers even if the body raises
        for method_name, original_init in init_func.items():
            setattr(torch.nn.init, method_name, original_init)
def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
    """Returns a parameter dictionary keyed by the expert parallel size of
    each parameter. Since data parallel parameters are replicated on every
    GPU, their ep_size is set to 1.

    Args:
        model (:class:`torch.nn.Module`): A PyTorch `nn.Module` from which we get the dict.
    """
    epsize_param_dict = dict()
    for param in model.parameters():
        if not is_moe_tensor(param):
            ep_size = 1  # set ep_size to 1 for dp parameters
        else:
            ep_size = dist.get_world_size(param.ep_group)
        if ep_size not in epsize_param_dict:
            epsize_param_dict[ep_size] = []
        epsize_param_dict[ep_size].append(param)

    return epsize_param_dict
def sync_moe_model_param(model: nn.Module):
    """Make sure model parameters are consistent in MoE parallel context.

    Args:
        model (:class:`torch.nn.Module`): A PyTorch model whose parameters are synchronized.
    """
    param_dict = get_moe_epsize_param_dict(model)

    # synchronize the parameters whose dp_group is the whole world
    if 1 in param_dict:
        for param in param_dict[1]:
            dist.broadcast(param, src=0)

    for ep_size in param_dict:
        # When ep_size = world_size, communication is not needed
        if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
            for param in param_dict[ep_size]:
                src_rank = get_process_group_ranks(param.dp_group)[0]
                dist.broadcast(param, src=src_rank, group=param.dp_group)
def set_moe_args(config: Any, args: dict):
    for k, v in args.items():
        setattr(config, k, v)
def create_ep_hierarchical_group(
    ep_group_ranks: List[int],
    nproc_per_node: Optional[int] = None,
) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
    """
    e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
        Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
    """
    assert dist.is_initialized(), "Please initialize torch.distributed first."
    rank = dist.get_rank()
    if nproc_per_node is None:
        nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
        assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
        nproc_per_node = int(nproc_per_node)
    else:
        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
    num_node = dist.get_world_size() // nproc_per_node

    intra_src_rank = None
    ep_intra_node_group = None
    for i in range(num_node):
        # membership is tested on the local index j, not on the global rank
        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
        # every rank must call new_group, including ranks outside the group
        group = dist.new_group(ep_intra_ranks)
        if rank in ep_intra_ranks:
            assert ep_intra_node_group is None
            ep_intra_node_group = group
            intra_src_rank = ep_intra_ranks[0]

    ep_inter_node_group = None
    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
    if len(ep_inter_ranks) > 1:
        group = dist.new_group(ep_inter_ranks)
        if rank in ep_inter_ranks:
            ep_inter_node_group = group

    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
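The rank bookkeeping in `create_ep_hierarchical_group` can be checked without launching a distributed job. The sketch below reproduces only the list arithmetic from the function above; the helper name `hierarchical_ep_ranks` and the explicit `world_size` argument are assumptions made for illustration, and no process groups are created.

```python
from typing import List, Tuple


def hierarchical_ep_ranks(
    ep_group_ranks: List[int], nproc_per_node: int, world_size: int
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: returns (intra-node rank lists, inter-node rank list)."""
    assert world_size % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
    num_node = world_size // nproc_per_node

    # One intra-node list per node. As in the source, membership is tested on
    # the local index j, so this assumes the EP group uses the same local
    # offsets on every node and includes ranks on node 0.
    intra = [
        [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
        for i in range(num_node)
    ]

    # The inter-node list links the first rank of the group across nodes.
    inter = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
    return intra, inter


if __name__ == "__main__":
    intra, inter = hierarchical_ep_ranks([1, 2, 5, 6], nproc_per_node=4, world_size=8)
    print(intra)  # [[1, 2], [5, 6]]
    print(inter)  # [1, 5]
```

This matches the docstring example: ranks 1 and 5 belong to the inter-node group, while ranks 2 and 6 only take part in intra-node communication.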