mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-30 06:00:00 +00:00
* [moe] removed openmoe-coupled code and rectify mixstral code (#5471) * [Feauture] MoE refractor; Intergration with Mixtral (#5682) * cherry pick from refractor-moe branch * tests passed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support ep + zero --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add mixtral auto policy & move pipeline forward code to modeling folder * [moe refactor] modify kernel test without Route Class * [moe refactor] add moe tensor test path environment variable to github workflow * fix typos * fix moe test bug due to the code rebase * [moe refactor] fix moe zero test, and little bug in low level zero * fix typo * add moe tensor path to github workflow * remove some useless code * fix typo & unify global variable XX_AXIS logic without using -1 * fix typo & prettifier the code * remove print code & support zero 2 test * remove useless code * reanme function * fix typo * fix typo * Further improve the test code * remove print code * [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test * [moe refactor] skip some unit test which will be refactored later * [moe refactor] fix unit import error * [moe refactor] fix circular import issues * [moe refactor] remove debug code * [moe refactor] update github workflow * [moe/zero] refactor low level optimizer (#5767) * [zero] refactor low level optimizer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] MoE refactor with newest version of ZeRO (#5801) * [zero] remove redundant members in BucketStore (#5802) * [zero] align api with previous version * [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819) * [moe 
refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * [hotfix]Solve the compatibility issue of zero refactor (#5823) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * Modify function parameter names to resolve compatibility issues * [zero] fix missing hook removal (#5824) * [MoE] Resolve .github conflict (#5829) * [Fix/Example] Fix Llama Inference Loading Data Type (#5763) * [fix/example] fix llama inference loading dtype * revise loading dtype of benchmark llama3 * [release] update version (#5752) * [release] update version * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [test] fix ddp plugin test * [test] fix gptj and rpc test * [devops] fix cuda ext compatibility * [inference] fix flash decoding test * [inference] fix flash decoding test * fix (#5765) * [test] Fix/fix testcase (#5770) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [Hotfix] Add missing init file in inference.executor (#5774) * [CI/tests] simplify some test case to reduce testing time (#5755) * [ci/tests] simplify some test case to reduce testing time * [ci/tests] continue to remove test case to reduce ci time cost * restore some test config * [ci/tests] continue to reduce ci time cost * [misc] update dockerfile 
(#5776) * [misc] update dockerfile * [misc] update dockerfile * [devops] fix docker ci (#5780) * [Inference]Add Streaming LLM (#5745) * Add Streaming LLM * add some parameters to llama_generation.py * verify streamingllm config * add test_streamingllm.py * modified according to the opinions of review * add Citation * change _block_tables tolist * [hotfix] fix llama flash attention forward (#5777) * [misc] Accelerate CI for zero and dist optim (#5758) * remove fp16 from lamb * remove d2h copy in checking states --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Test/CI] remove test cases to reduce CI duration (#5753) * [test] smaller gpt2 test case * [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py * [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py * [test] reduce test cases tests/test_zero/test_gemini/test_optim.py * Revert "[test] smaller gpt2 test case" Some tests might depend on the size of model (num of chunks) This reverts commitdf705a5210. * [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py * [CI] smaller test model for two mwo the two modifid cases * [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there * [hotfix] fix testcase in test_fx/test_tracer (#5779) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [fix] fix test_deepfm_model & test_dlrf_model; * [fix] fix test_hf_albert & test_hf_gpt; * [gemini] optimize reduce scatter d2h copy (#5760) * [gemini] optimize reduce scatter d2h copy * [fix] fix missing reduce variable * [refactor] remove legacy async reduce scatter code * [gemini] missing sync * Revert "[refactor] remove legacy async reduce scatter code" This reverts commit58ad76d466. 
* [gemini] further optimize with async all reduce * [fix] pass flag from manager to chunk * Allow building cuda extension without a device. (#5535) Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are. * [misc] fix dist logger (#5782) * [install]fix setup (#5786) * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update requirements (#5787) * [shardformer] fix import (#5788) * upgrade colossal-chat support tp_group>1, add sp for sft * upgrade ppo dpo rm script * run pre-commit * moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy * fix training script * fix ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix transformers version * remove duplicated test * fix datasets version * remove models that require huggingface auth from ci * remove local data path * update ci * remove baichuan from template test due to transformer version conflict * merge * Refactor modeling by adding attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix tests and naming Signed-off-by: char-1ee <xingjianli59@gmail.com> * Pass inference model shard configs for module init Signed-off-by: char-1ee <xingjianli59@gmail.com> * Clean up Signed-off-by: char-1ee <xingjianli59@gmail.com> * replace the customized dataloader setup with the build-in one * replace the customized dataloader setup with the build-in one * Remove flash attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * fix readme * Fix test import Signed-off-by: char-1ee <xingjianli59@gmail.com> * update sft trainning script * [Inference]refactor baichuan (#5791) * refactor baichuan * remove unused code and add TODO for lazyinit * [test] fix chatglm test kit (#5793) * 
[shardformer] fix modeling of bloom and falcon (#5796) * [test] fix qwen2 pytest distLarge (#5797) * [Inference] Fix flash-attn import and add model test (#5794) * Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> * [Gemini] Use async stream to prefetch and h2d data moving (#5781) * use async stream to prefetch and h2d data moving * Remove redundant code * [gemini] quick fix on possible async operation (#5803) * [gemini] quick fix on possible async operation * [gemini] quick fix on possible async operation * [shardformer] upgrade transformers to 4.39.3 (#5815) * [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807) * [shardformer] fix modeling of gpt2 and gptj * [shardformer] fix whisper modeling * [misc] update requirements --------- Co-authored-by: ver217 <lhx0217@gmail.com> * [shardformer]upgrade transformers for mistral (#5808) * upgrade transformers for mistral * fix * fix * [shardformer]upgrade transformers for llama (#5809) * update transformers fix * fix * fix * [inference] upgrade transformers (#5810) * update transformers fix * fix * fix * fix * fix * [gemini] update transformers for gemini (#5814) --------- Co-authored-by: ver217 <lhx0217@gmail.com> * Support 4d parallel + flash attention (#5789) * support tp + sp + pp * remove comments --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> 
Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> * [zero] fix hook bug * [zero] add low level optimizer back (#5839) * [zero] fix param & refactor * [zero] add back original low level opt * [zero] remove moe related * [zero] pass zero tests * [zero] refactor * [chore] add del func back * [zero] comments and naming (#5840) * [zero] modify api (#5843) * [zero] modify api * [test] remove _grad_store access in tests * [test] fix (#5857) * [CI] skip openmoe CI check * [CI] fox pre-commit * [zero] remove redundant memebr init (#5862) * [misc] remove useless code, modify the pg mesh implementation * [misc] remove useless code, modify the pg mesh implementation * [misc] use tempfile * resolve conflict with main branch * [misc] use tempfile in test_moe_checkpoint.py * [misc] remove useless code, add assertion about sequence parallel, move logger into function * [misc] remove useless code --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> 
Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
219 lines
7.3 KiB
Python
219 lines
7.3 KiB
Python
import contextlib
|
|
import os
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.distributed.distributed_c10d import get_process_group_ranks
|
|
|
|
from colossalai.accelerator import get_accelerator
|
|
from colossalai.moe.manager import MOE_MANAGER
|
|
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
|
|
|
|
|
class ForceFP32Parameter(torch.nn.Parameter):
    """A ``torch.nn.Parameter`` whose ``half()`` does not down-cast.

    Calling ``half()`` returns a fresh copy of ``self.data`` in its current
    dtype instead of a float16 tensor.
    """

    def half(self, memory_format=None):
        # ``memory_format`` is accepted only to match the Tensor.half signature;
        # it is ignored. The copy keeps the original dtype (no fp16 cast).
        fp32_copy = self.data.clone()
        return fp32_copy
|
|
|
|
|
|
class NormalNoiseGenerator:
    """Additive Gaussian noise for a logits tensor.

    Noise is sampled from ``Normal(loc=0, scale=1 / E**2)``, where ``E`` is the
    number of experts. ``scale`` is the standard deviation (the
    ``torch.distributions.Normal`` convention), so the variance is ``1 / E**4``.

    NOTE(review): the distribution parameters are created with the default
    dtype, so adding the noise to fp16/bf16 inputs type-promotes the result;
    confirm this is intended.

    Args:
        num_experts (int): The number of experts ``E``.
    """

    def __init__(self, num_experts: int):
        device = get_accelerator().get_current_device()
        mean = torch.tensor(0.0, device=device)
        std = torch.tensor(1.0 / num_experts**2, device=device)
        # Keep only the reparameterized sampler; it is called with a sample shape.
        self.normal = torch.distributions.normal.Normal(loc=mean, scale=std).rsample

    def __call__(self, inputs: torch.Tensor):
        # One independent noise value per element of ``inputs``.
        return inputs + self.normal(inputs.shape)
|
|
|
|
|
|
class UniformNoiseGenerator:
    """Multiplicative uniform jitter for a logits tensor.

    Copied from Mesh TensorFlow: every value is multiplied by a random factor
    drawn uniformly from ``[1 - eps, 1 + eps)``. This makes models more
    resilient to rounding errors introduced by bfloat16, which seems
    particularly important for logits.

    NOTE(review): the bounds are created with the default dtype, so multiplying
    fp16/bf16 inputs type-promotes the result; confirm this is intended.

    Args:
        eps (float, optional): Half-width of the jitter interval. Defaults to 1e-2.
    """

    def __init__(self, eps: float = 1e-2):
        device = get_accelerator().get_current_device()
        lower = torch.tensor(1.0 - eps, device=device)
        upper = torch.tensor(1.0 + eps, device=device)
        # Keep only the reparameterized sampler; it is called with a sample shape.
        self.uniform = torch.distributions.uniform.Uniform(low=lower, high=upper).rsample

    def __call__(self, inputs: torch.Tensor):
        # One independent scale factor per element of ``inputs``.
        return inputs * self.uniform(inputs.shape)
|
|
|
|
|
|
def autocast_softmax(logit: torch.Tensor, dim: int):
    """Softmax computed and returned in float32 regardless of the input dtype.

    ``F.softmax`` casts ``logit`` to ``dtype`` before the operation, which
    avoids precision loss or overflow when the logits are fp16/bf16.

    Args:
        logit (torch.Tensor): Input logits.
        dim (int): Dimension along which softmax is computed.

    Returns:
        torch.Tensor: float32 probabilities that sum to 1 along ``dim``.
    """
    # The keyword is ``dtype``; the previous ``detype`` raised a TypeError on every call.
    return F.softmax(logit, dim=dim, dtype=torch.float32)
|
|
|
|
|
|
def get_noise_generator(noise_type: Optional[str], num_experts: int) -> Optional[Callable]:
    """Return the logits-noise callable selected by ``noise_type``.

    Args:
        noise_type: ``None`` for no noise; ``"Jitter"`` for multiplicative
            uniform jitter (``UniformNoiseGenerator`` with its default eps); or
            ``"Gaussian"`` for additive normal noise (``NormalNoiseGenerator``).
        num_experts: Number of experts. Only used by the ``"Gaussian"`` generator.

    Returns:
        The noise generator, or ``None`` when ``noise_type`` is ``None``.

    Raises:
        NotImplementedError: If ``noise_type`` is not one of the supported values.
    """
    if noise_type is None:
        return None
    if noise_type == "Jitter":
        return UniformNoiseGenerator()
    if noise_type == "Gaussian":
        return NormalNoiseGenerator(num_experts)
    raise NotImplementedError("Unsupported input noisy policy")
|
|
|
|
|
|
def get_activation(act: str) -> Callable:
    """Map an activation name to a callable.

    ``None`` and ``"relu"`` give a new ``nn.ReLU``; ``"gelu"`` and ``"silu"``
    give new ``nn.GELU`` / ``nn.SiLU`` modules; ``"swiglu"`` returns the
    ``SwiGLU`` function itself (not an instance).

    Raises:
        NotImplementedError: For any other name.
    """
    if act is None:
        return torch.nn.ReLU()
    if act == "swiglu":
        # SwiGLU is a plain function, so it is handed back uncalled.
        return SwiGLU

    module_factories = (
        ("relu", torch.nn.ReLU),
        ("gelu", torch.nn.GELU),
        ("silu", torch.nn.SiLU),
    )
    for name, factory in module_factories:
        if act == name:
            # A fresh module instance is built on every call.
            return factory()

    raise NotImplementedError("Unsupported activation function")
|
|
|
|
|
|
def SwiGLU(x, axis: int = -1):
    """Swish-gated linear unit activation.

    ``x`` is split into two equal halves ``x1`` and ``x2`` along ``axis``, and
    the result is ``x1 * (x2 * sigmoid(x2))``, i.e. ``x1 * swish(x2)``. The
    output is therefore half the size of ``x`` along ``axis``.

    Args:
        x (torch.Tensor): Input tensor. Its size along ``axis`` must be even.
        axis (int, optional): Dimension along which the split is computed.
            Defaults to -1 (the last dimension), which matches the previous
            hard-coded behaviour.

    Returns:
        torch.Tensor: The gated activation.

    Raises:
        AssertionError: If the size of ``x`` along ``axis`` is odd.
    """
    size = x.shape[axis]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, axis)
    return x1 * (x2 * torch.sigmoid(x2))
|
|
|
|
|
|
@contextlib.contextmanager
def skip_init():
    """Temporarily turn common ``torch.nn.init`` initializers into no-ops.

    Inside the ``with`` block, the initializers listed below are replaced on
    the ``torch.nn.init`` module by a function that does nothing and returns
    ``None``. Code that looks them up through the module at call time therefore
    leaves tensors untouched. The originals are restored on exit, including when
    the body raises.

    Note: a reference bound before entering (e.g. ``from torch.nn.init import
    normal_``) is not affected by the patch.
    """

    def _skip_init(*args, **kwargs):
        pass

    init_func = {
        "constant_": torch.nn.init.constant_,
        "uniform_": torch.nn.init.uniform_,
        "normal_": torch.nn.init.normal_,
        "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
        "kaiming_normal_": torch.nn.init.kaiming_normal_,
        "xavier_normal_": torch.nn.init.xavier_normal_,
        "xavier_uniform_": torch.nn.init.xavier_uniform_,
        "trunc_normal_": torch.nn.init.trunc_normal_,
    }

    for method_name in init_func:
        setattr(torch.nn.init, method_name, _skip_init)
    try:
        yield
    finally:
        # Always restore: without ``finally`` an exception in the ``with`` body
        # would leave torch.nn.init globally patched for the rest of the process.
        for method_name, original_init in init_func.items():
            setattr(torch.nn.init, method_name, original_init)
|
|
|
|
|
|
def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
    """Group a model's parameters by expert-parallel size.

    MoE parameters (per ``is_moe_tensor``) are keyed by the world size of their
    ``ep_group``. Every other parameter is data-parallel, replicated on each GPU,
    and is keyed under ``1``. Parameters keep ``model.parameters()`` order within
    each bucket.

    Args:
        model (:class:`torch.nn.Module`): The module whose parameters are grouped.

    Returns:
        Dict[int, List[nn.Parameter]]: ``ep_size -> parameters`` with that size.
    """
    grouped: Dict[int, List[nn.Parameter]] = {}
    for p in model.parameters():
        size = dist.get_world_size(p.ep_group) if is_moe_tensor(p) else 1
        grouped.setdefault(size, []).append(p)
    return grouped
|
|
|
|
|
|
def sync_moe_model_param(model: nn.Module):
    """Broadcast parameters so they are consistent in an MoE parallel context.

    Non-MoE parameters (ep_size 1) are broadcast from global rank 0 over the
    default group. MoE parameters are broadcast inside their ``dp_group`` from
    that group's first rank. Groups whose ep_size equals
    ``MOE_MANAGER.world_size`` are skipped.

    The broadcasts are collectives, so every rank must issue them in the same
    order: first all ep_size-1 params, then the remaining buckets in dict order.

    Args:
        model (:class:`torch.nn.Module`): Model whose parameters are synchronized.
    """
    params_by_ep_size = get_moe_epsize_param_dict(model)

    # Parameters replicated across the whole world: rank 0 is the source.
    for dense_param in params_by_ep_size.get(1, []):
        dist.broadcast(dense_param, src=0)

    for ep_size, moe_params in params_by_ep_size.items():
        # ep_size 1 is already handled; when ep_size == world size no communication is needed.
        # ``or`` keeps MOE_MANAGER.world_size unread for the ep_size 1 bucket.
        if ep_size == 1 or ep_size == MOE_MANAGER.world_size:
            continue
        for moe_param in moe_params:
            leader = get_process_group_ranks(moe_param.dp_group)[0]
            dist.broadcast(moe_param, src=leader, group=moe_param.dp_group)
|
|
|
|
|
|
def set_moe_args(config: Any, args: dict):
    """Copy every ``key -> value`` entry of ``args`` onto ``config`` as an attribute."""
    for attr_name in args:
        setattr(config, attr_name, args[attr_name])
|
|
|
|
|
|
def create_ep_hierarchical_group(
    ep_group_ranks: List[int],
    nproc_per_node: Optional[int] = None,
) -> Tuple[Optional[int], Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
    """Build intra-node and inter-node process groups for an expert-parallel group.

    e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
    Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None

    Layout assumption: the EP group occupies the same local slots on every node.
    Slots are selected by testing the local index ``j`` against
    ``ep_group_ranks``, and the inter-node group is ``ep_group_ranks[0]`` plus
    multiples of ``nproc_per_node``. NOTE(review): groups that do not follow
    this symmetric layout are not handled; confirm at call sites.

    ``dist.new_group`` is called for every node's intra group (and the inter
    group) on every rank. torch requires all processes to create groups in the
    same order, so presumably every rank passes the same ``ep_group_ranks``;
    verify against callers.

    Args:
        ep_group_ranks: Global ranks forming the expert-parallel group.
        nproc_per_node: Processes per node. When omitted it is read from the
            ``LOCAL_WORLD_SIZE`` environment variable (set by torchrun).

    Returns:
        ``(intra_src_rank, ep_intra_node_group, ep_inter_node_group)``:
        the first rank of the intra-node group containing this rank and that
        group (both ``None`` if this rank is in none), and the inter-node group
        if this rank belongs to it, else ``None``.
    """
    assert dist.is_initialized(), "Please initialize torch.distributed first."
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    if nproc_per_node is None:
        nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
        assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
        nproc_per_node = int(nproc_per_node)
    # Checked for both sources: a LOCAL_WORLD_SIZE that does not divide the
    # world size would otherwise silently truncate num_node below.
    assert world_size % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
    num_node = world_size // nproc_per_node

    ep_rank_set = set(ep_group_ranks)
    intra_src_rank = None
    ep_intra_node_group = None
    for i in range(num_node):
        # Local slot j is kept when it appears in ep_group_ranks (see layout note above).
        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_rank_set]
        group = dist.new_group(ep_intra_ranks)
        if rank in ep_intra_ranks:
            assert ep_intra_node_group is None
            ep_intra_node_group = group
            intra_src_rank = ep_intra_ranks[0]

    ep_inter_node_group = None
    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
    if len(ep_inter_ranks) > 1:
        group = dist.new_group(ep_inter_ranks)
        if rank in ep_inter_ranks:
            ep_inter_node_group = group

    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
|