[gemini] gemini support tensor parallelism. (#4942)
* [colossalai]fix typo
* [inference] Add smoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843)
* [inference] add smoothquant llama attention (#4850)
* add smoothquant llama attention
* remove useless code
* remove useless code
* fix import error
* rename file name
* [inference] add silu linear fusion for smoothquant llama mlp (#4853)
* add silu linear
* update skip condition
* catch smoothquant cuda lib exception
* process exception for tests
* [inference] add llama mlp for smoothquant (#4854)
* add llama mlp for smoothquant
* fix down out scale
* remove duplicate lines
* add llama mlp check
* delete useless code
* [inference] add smoothquant llama (#4861)
* add smoothquant llama
* fix attention accuracy
* fix accuracy
* add kv cache and save pretrained
* refactor example
* delete smooth
* refactor code
* [inference] add smooth function and delete useless code for smoothquant (#4895)
* add smooth function and delete useless code
* update datasets
* remove duplicate import
* delete useless file
* refactor codes (#4902)
* refactor code
* add license
* add torch-int and smoothquant license
* Update flash_attention_patch.py
  To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to the forward function of the attention layer.
  https://github.com/huggingface/transformers/pull/25598
* [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)
* [kernel] support pure fp16 for cpu adam (#4896)
* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)
* [kernel] fix cpu adam
* [test] update gemini optim test
* [format] applied code formatting on changed files in pull request 4908 (#4918)
  Co-authored-by: github-actions <github-actions@github.com>
* [gemini] support gradient accumulation (#4869)
* add test
* fix no_sync bug in low level zero plugin
* fix test
* add argument for grad accum
* add grad accum in backward hook for gemini
* finish implementation, rewrite tests
* fix test
* skip stuck model in low level zero test
* update doc
* optimize communication & fix gradient checkpoint
* modify doc
* cleaning codes
* update cpu adam fp16 case
* [hotfix] fix torch 2.0 compatibility (#4936)
* [hotfix] fix launch
* [test] fix test gemini optim
* [shardformer] fix vit
* [test] add no master test for low level zero plugin (#4934)
* [format] applied code formatting on changed files in pull request 4820 (#4886)
  Co-authored-by: github-actions <github-actions@github.com>
* [nfc] fix some typo with colossalai/ docs/ etc. (#4920)
* [Refactor] Integrated some lightllm kernels into token-attention (#4946)
* add some req for inference
* clean codes
* add codes
* add some lightllm deps
* clean codes
* hello
* delete rms files
* add some comments
* add comments
* add doc
* add lightllm deps
* add lightllm chatglm2 kernels
* add lightllm chatglm2 kernels
* replace rotary embedding with lightllm kernel
* add some comments
* add some comments
* add some comments
* add
* replace fwd kernel att1
* fix a arg
* add
* add
* fix token attention
* add some comments
* clean codes
* modify comments
* fix readme
* fix bug
* fix bug
  ---------
  Co-authored-by: cuiqing.li <lixx336@gmail.com>
  Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo
* [test] update legacy test
* [test] update model zoo
* [test] update gemini test
* [test] remove components to test
* [inference] add reference and fix some bugs (#4937)
* add reference and fix some bugs
* update gptq init
  ---------
  Co-authored-by: Xu Kai <xukai16@foxamil.com>
* [Inference]ADD Bench Chatglm2 script (#4963)
* add bench chatglm
* fix bug and make utils
  ---------
  Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
* [Pipeline inference] Combine kvcache with pipeline inference (#4938)
* merge kvcache with pipeline inference and refactor the code structure
* support ppsize > 2
* refactor pipeline code
* do pre-commit
* modify benchmark
* fix benchmark
* polish code
* add docstring and update readme
* refactor the code
* fix some logic bug of ppinfer
* polish readme
* fix typo
* skip infer test
* updated c++17 compiler flags (#4983)
* [Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)
* finish batch manager
* 1
* first
* fix
* fix dynamic batching
* llama infer
* finish test
* support different lengths generating
* del prints
* del prints
* fix
* fix bug
  ---------
  Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
* [inference] Async dynamic batching (#4894)
* finish input and output logic
* add generate
* test forward
* 1
* [inference]Re push async dynamic batching (#4901)
* adapt to ray server
* finish async
* finish test
* del test
  ---------
  Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)
  This reverts commit fbf3c09e67.
* Revert "[inference] Async dynamic batching (#4894)"
  This reverts commit fced140250.
* Revert "[inference] Async dynamic batching (#4894)" (#4909)
  This reverts commit fced140250.
* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* [infer]Add Ray Distributed Environment Init Scripts (#4911)
* Revert "[inference] Async dynamic batching (#4894)"
  This reverts commit fced140250.
* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* support dynamic batch for bloom model and is_running function
* [Inference]Test for new Async engine (#4935)
* infer engine
* infer engine
* test engine
* test engine
* new manager
* change step
* add
* test
* fix
* fix
* finish test
* finish test
* finish test
* finish test
* add license
  ---------
  Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* add assertion for config (#4947)
* [Inference] Finish dynamic batching offline test (#4948)
* test
* fix test
* fix quant
* add default
* fix
* fix some bugs
* fix some bugs
* fix
* fix bug
* fix bugs
* reset param
  ---------
  Co-authored-by: yuehuayingxueluo <867460659@qq.com>
  Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
  Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
* [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965)
* adding flash-decoding
* clean
* adding kernel
* adding flash-decoding
* add integration
* add
* adding kernel
* adding kernel
* adding triton 2.1.0 features for inference
* update bloom triton kernel
* remove useless vllm kernels
* clean codes
* fix
* adding files
* fix readme
* update llama flash-decoding
  ---------
  Co-authored-by: cuiqing.li <lixx336@gmail.com>
* fix ColossalEval (#4992)
  Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
* [doc]Update doc for colossal-inference (#4989)
* update doc
* Update README.md
  ---------
  Co-authored-by: cuiqing.li <lixx336@gmail.com>
* [hotfix] Fix the bug where process groups were not being properly released. (#4940)
* Fix the bug where process groups were not being properly released.
* test
* Revert "test"
  This reverts commit 479900c139.
* [hotfix] fix the bug of repeatedly storing param group (#4951)
* [doc] add supported feature diagram for hybrid parallel plugin (#4996)
* [Pipeline Inference] Merge pp with tp (#4993)
* refactor pipeline into new CaiInferEngine
* update llama modeling forward
* merge tp with pp
* update docstring
* optimize test workflow and example
* fix typo
* add assert and todo
* [release] update version (#4995)
* [release] update version
* [hotfix] fix ci
* [gemini] gemini support tp
* fix
* update checkpointIO
* support fused layernorm
* update fusedlayernorm
* add sequence parallel to gemini
* fix
* fix comments
* fix
* fix t5
* clear cache
* fix
* activate ci
* activate ci
* fix
* fix
* fix
* fix
* revert
* modify tp gather method
* fix test
  ---------
  Co-authored-by: Xu Kai <xukai16@foxmail.com>
  Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
  Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  Co-authored-by: github-actions <github-actions@github.com>
  Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
  Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
  Co-authored-by: digger yu <digger-yu@outlook.com>
  Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
  Co-authored-by: cuiqing.li <lixx336@gmail.com>
  Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
  Co-authored-by: Xu Kai <xukai16@foxamil.com>
  Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
  Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
  Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
  Co-authored-by: yuehuayingxueluo <867460659@qq.com>
  Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
  Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
  Co-authored-by: littsk <1214689160@qq.com>
  Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
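The headline change is that GeminiPlugin can now combine Shardformer tensor parallelism with its ZeRO/Gemini data parallelism. The sketch below is not taken from the repo's examples; it assumes a `torchrun` launch whose world size is a multiple of `tp_size` (with `world_size // tp_size > 1`, since the plugin asserts a DP group larger than 1), the Booster API of this release, and a placeholder small GPT-2 model.

```python
# Minimal sketch: enabling the new tensor-parallel path of GeminiPlugin.
# Run under torchrun with, e.g., 4 GPUs so that dp_size = 4 // tp_size = 2 > 1.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})  # launch API of this release

# Any model with a Shardformer policy works; a tiny GPT-2 is used only for illustration.
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=4, n_embd=256))
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(
    precision="bf16",
    enable_tensor_parallelism=True,  # new in this commit
    tp_size=2,                       # size of the tensor-parallel group
)
booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```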
@@ -5,6 +5,7 @@ from pathlib import Path
 from typing import Callable, Iterator, List, Optional, Tuple
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -19,8 +20,9 @@ from colossalai.checkpoint_io.utils import (
     save_state_dict,
     save_state_dict_shards,
 )
-from colossalai.cluster import DistCoordinator
+from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -32,7 +34,25 @@ __all__ = ["GeminiPlugin"]
|
||||
SUPPORTED_PRECISION = ["fp16", "bf16"]
|
||||
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
|
||||
|
||||
DP_AXIS = 0
|
||||
TP_AXIS = 1
|
||||
|
||||
def get_param_info(optim: Optimizer):
|
||||
# Get a backup of necessary information of parameters for future use, which includes:
|
||||
# 1. A mapping from integer param_id to param32 shape.
|
||||
|
||||
if optim is None:
|
||||
return {}
|
||||
param_info = {"id2shape": {}}
|
||||
start_index = 0
|
||||
for group in optim.param_groups:
|
||||
for param_id, param in enumerate(group["params"], start_index):
|
||||
original_shape = param.shape if isinstance(param, torch.Tensor) else None
|
||||
param_info["id2shape"][param_id] = original_shape
|
||||
|
||||
start_index += len(group["params"])
|
||||
|
||||
return param_info
|
||||
class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
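For reference, here is a small self-contained illustration of what the new `get_param_info` helper records. The helper body is copied from the hunk above; the toy optimizer exists only for demonstration.

```python
import torch
from torch.optim import Adam, Optimizer


def get_param_info(optim: Optimizer):
    # Copied from the hunk above: maps an integer param_id, numbered consecutively
    # across param groups, to the original parameter shape.
    if optim is None:
        return {}
    param_info = {"id2shape": {}}
    start_index = 0
    for group in optim.param_groups:
        for param_id, param in enumerate(group["params"], start_index):
            original_shape = param.shape if isinstance(param, torch.Tensor) else None
            param_info["id2shape"][param_id] = original_shape
        start_index += len(group["params"])
    return param_info


# Two param groups: ids are assigned 0, 1 across groups.
linear = torch.nn.Linear(4, 2)
optim = Adam([{"params": [linear.weight]}, {"params": [linear.bias]}], lr=1e-3)
print(get_param_info(optim)["id2shape"])
# {0: torch.Size([2, 4]), 1: torch.Size([2])}
```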
@@ -284,6 +304,16 @@ class GeminiPlugin(DPPluginBase):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
+        enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
+        tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
+        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+            Currently all the optimization methods include fused normalization, flash attention and JIT.
+            Defaults to False.
+        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
     """
 
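Per this docstring, `enable_all_optimization` is an umbrella for the individual Shardformer switches (fused normalization, flash attention, JIT). A hedged sketch of the two constructions it implies; note that after this commit the plugin builds a process-group mesh in `__init__`, so both require a distributed launch.

```python
# Sketch only: equivalence of the umbrella flag and the individual flags is asserted
# by the docstring, not verified here. Must run under a distributed launch.
from colossalai.booster.plugin import GeminiPlugin

plugin_a = GeminiPlugin(enable_tensor_parallelism=True, tp_size=2, enable_all_optimization=True)

plugin_b = GeminiPlugin(
    enable_tensor_parallelism=True,
    tp_size=2,
    enable_fused_normalization=True,
    enable_flash_attention=True,
    enable_jit_fused=True,
)
```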
@@ -317,6 +347,14 @@ class GeminiPlugin(DPPluginBase):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
+        enable_tensor_parallelism: bool = False,
+        tp_size: int = 1,
+        enable_all_optimization: bool = False,
+        enable_fused_normalization: bool = False,
+        enable_flash_attention: bool = False,
+        enable_sequence_parallelism: bool = False,
+        enable_jit_fused: bool = False,
+        enable_sequence_overlap: bool = False,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -355,8 +393,32 @@ class GeminiPlugin(DPPluginBase):
             max_norm=max_norm,
             norm_type=norm_type,
         )
+        self.enable_tensor_parallelism = enable_tensor_parallelism
+        self.enable_all_optimization = enable_all_optimization
+        self.enable_fused_normalization = enable_fused_normalization
+        self.enable_flash_attention = enable_flash_attention
+        self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
+        self.enable_jit_fused = enable_jit_fused
+        self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose
 
+        self.tp_size = tp_size if self.enable_tensor_parallelism else 1
+        self.dp_size = dist.get_world_size() // self.tp_size
+        assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
+        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.shard_config = ShardConfig(
+            tensor_parallel_process_group=self.tp_group,
+            enable_tensor_parallelism=self.enable_tensor_parallelism,
+            enable_all_optimization=self.enable_all_optimization,
+            enable_fused_normalization=self.enable_fused_normalization,
+            enable_flash_attention=self.enable_flash_attention,
+            enable_jit_fused=self.enable_jit_fused,
+            enable_sequence_parallelism=self.enable_sequence_parallelism,
+            enable_sequence_overlap=self.enable_sequence_overlap,
+        )
+
     def support_no_sync(self) -> bool:
         return False
 
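The constructor now carves the world into a 2-D process-group mesh with `DP_AXIS = 0` and `TP_AXIS = 1`, then takes the DP and TP groups along those axes. The framework-free sketch below shows the assumed rank layout, row-major over a `(dp_size, tp_size)` grid, which is how such a mesh is conventionally laid out; the real grouping is done by `colossalai.cluster.ProcessGroupMesh`.

```python
# Illustration only: which ranks would share a TP group (same row) and a DP group
# (same column) for 8 ranks and tp_size = 2, mirroring self.dp_size/self.tp_size above.
world_size = 8
tp_size = 2
dp_size = world_size // tp_size  # 4; the plugin asserts this is > 1

for rank in range(world_size):
    dp_coord, tp_coord = divmod(rank, tp_size)
    tp_peers = [dp_coord * tp_size + t for t in range(tp_size)]  # shard layers together
    dp_peers = [d * tp_size + tp_coord for d in range(dp_size)]  # shard data/ZeRO states together
    print(f"rank {rank}: tp_group={tp_peers}, dp_group={dp_peers}")
```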
@@ -380,6 +442,7 @@ class GeminiPlugin(DPPluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        optimizer_params_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
             # FIXME(ver217): gemini does not support sync bn
@@ -391,11 +454,21 @@ class GeminiPlugin(DPPluginBase):
             # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
 
             # wrap the model with Gemini
-            model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
+            if self.enable_tensor_parallelism:
+                shardformer = ShardFormer(self.shard_config)
+                model, _ = shardformer.optimize(model)
+
+            model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
-                optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+                optimizer,
+                model,
+                **self.zero_optim_config,
+                **self.optim_kwargs,
+                tp_group=self.tp_group,
+                optimizer_params_info=optimizer_params_info,
+                verbose=self.verbose,
             )
 
         return model, optimizer, criterion, dataloader, lr_scheduler
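After `configure()`, the caller gets back the Shardformer-optimized, Gemini-wrapped model and a `GeminiOptimizer` behind the `OptimizerWrapper` interface. A hedged sketch of one training step, continuing the usage sketch after the commit message; per the docstring, gradient clipping (`max_norm`/`norm_type`) is handled inside the ZeRO optimizer, so there is no manual `clip_grad_norm` call.

```python
# Continues the earlier sketch: `model`, `optimizer`, `booster` are the boosted objects.
# Placeholder data; 50257 is the GPT-2 vocabulary size used in that sketch.
import torch

input_ids = torch.randint(0, 50257, (4, 128), device="cuda")

outputs = model(input_ids=input_ids, labels=input_ids)
loss = outputs.loss

booster.backward(loss, optimizer)  # backward is routed through the Gemini/ZeRO machinery
optimizer.step()                   # clipping per max_norm/norm_type happens internally
optimizer.zero_grad()
```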
@@ -407,4 +480,4 @@ class GeminiPlugin(DPPluginBase):
         return GeminiCheckpointIO()
 
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
-        raise NotImplementedError
+        raise NotImplementedError
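Since `support_no_sync()` still returns `False` and `no_sync()` raises `NotImplementedError`, code written for plugins that do support no-sync gradient accumulation should feature-check first. A small sketch, reusing names from the earlier sketches:

```python
# Sketch: guard no-sync gradient accumulation, since GeminiPlugin reports
# support_no_sync() == False and its no_sync() raises NotImplementedError.
if plugin.support_no_sync():
    with plugin.no_sync(model, optimizer):  # gradients stay local, communication deferred
        booster.backward(loss, optimizer)
else:
    booster.backward(loss, optimizer)       # every backward participates in communication
```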