mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[gemini] gemini support tensor parallelism. (#4942)
* [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions <github-actions@github.com> * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions <github-actions@github.com> * [nfc] fix some typo with colossalai/ docs/ etc. (#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai <xukai16@foxamil.com> * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commitfbf3c09e67
. * Revert "[inference] Async dynamic batching (#4894)" This reverts commitfced140250
. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commitfced140250
. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commitfced140250
. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. * test * Revert "test" This reverts commit479900c139
. * [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> Co-authored-by: Xu Kai <xukai16@foxamil.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import torch.distributed as dist
|
||||
from packaging.version import Version
|
||||
from torch.nn import Parameter
|
||||
from torch.optim import Optimizer
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
||||
from colossalai.checkpoint_io.utils import StateDictSharder
|
||||
@@ -19,6 +20,18 @@ from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .gemini_ddp import GeminiDDP
|
||||
from colossalai.checkpoint_io.utils import gather_distributed_param
|
||||
from colossalai.tensor.d_tensor import (
|
||||
distribute_tensor,
|
||||
distribute_tensor_with_customization,
|
||||
init_tensor_as_customization_distributed,
|
||||
get_device_mesh,
|
||||
get_sharding_spec,
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
get_global_shape,
|
||||
init_as_dtensor
|
||||
)
|
||||
|
||||
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
|
||||
|
||||
@@ -93,6 +106,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0.0,
|
||||
norm_type: float = 2.0,
|
||||
tp_group: ProcessGroup = None,
|
||||
optimizer_params_info=None,
|
||||
verbose: bool = False,
|
||||
**defaults: Any,
|
||||
):
|
||||
@@ -109,6 +124,10 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
self.chunk16_set: Set[Chunk] = set()
|
||||
self.clipping_flag = max_norm > 0.0
|
||||
self.max_norm = max_norm
|
||||
self.tp_group = tp_group
|
||||
self.optimizer_params_info = optimizer_params_info
|
||||
self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
|
||||
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
|
||||
self.verbose = verbose
|
||||
self.param_groups_backup = list()
|
||||
|
||||
@@ -406,8 +425,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
param = self.id_to_real_params[param_id]
|
||||
fake_param = self.id_to_fake_params.get(param_id, None)
|
||||
chunk = self.chunk_manager.get_chunk(param)
|
||||
process_group = chunk.torch_pg
|
||||
rank = dist.get_rank(process_group)
|
||||
dp_group = chunk.torch_pg
|
||||
rank = dist.get_rank(dp_group)
|
||||
master_rank = 0
|
||||
collected_states = {}
|
||||
|
||||
@@ -415,9 +434,9 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
local_state_names = None
|
||||
if fake_param is not None:
|
||||
local_state_names = list(self.optim.state[fake_param].keys())
|
||||
gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
|
||||
gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
|
||||
dist.barrier()
|
||||
dist.all_gather_object(gathered_state_names, local_state_names)
|
||||
dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
|
||||
state_names = None
|
||||
for names in gathered_state_names:
|
||||
if names is not None:
|
||||
@@ -436,6 +455,13 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
# Every rank is collector when only_rank_0 is False.
|
||||
is_collector = (rank == master_rank) or (not only_rank_0)
|
||||
|
||||
# get tensor parallelism information
|
||||
is_dtensor = is_distributed_tensor(param)
|
||||
is_customized_distributed = is_customized_distributed_tensor(param)
|
||||
shard_spec = get_sharding_spec(param) if is_dtensor else None
|
||||
device_mesh = get_device_mesh(param) if is_dtensor else None
|
||||
global_shape = self.optimizer_params_info["id2shape"][param_id]
|
||||
|
||||
# If the chunk is kept gathered,
|
||||
# the parameteres are treated the same as that of those in strict DDP during training.
|
||||
# So states can be directly fetched from current device.
|
||||
@@ -451,7 +477,18 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
).cpu()
|
||||
else:
|
||||
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
|
||||
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
|
||||
if is_dtensor:
|
||||
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
||||
state_tensor = init_as_dtensor(state_tensor,
|
||||
device_mesh=device_mesh,
|
||||
sharding_spec=shard_spec,
|
||||
global_shape = global_shape)
|
||||
elif is_customized_distributed:
|
||||
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
||||
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
|
||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||
|
||||
collected_states[state_name] = state_tensor.reshape(global_shape)
|
||||
return collected_states
|
||||
|
||||
# Check whether the param with given id is managed by current process.
|
||||
@@ -473,7 +510,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
_, shard_offset, shard_size = self.get_offsets(param_id)
|
||||
|
||||
# Collectors gather state shards through all_gathering.
|
||||
gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]
|
||||
gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]
|
||||
|
||||
dist.barrier()
|
||||
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
|
||||
@@ -494,6 +531,16 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
for state_name, state_tensor in collected_states.items():
|
||||
if state_tensor.numel() == param.numel():
|
||||
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
|
||||
if is_dtensor:
|
||||
state_tensor = state_tensor.to(param.device)
|
||||
state_tensor = init_as_dtensor(state_tensor,
|
||||
sharding_spec=shard_spec,
|
||||
device_mesh=device_mesh,
|
||||
global_shape=global_shape)
|
||||
elif is_customized_distributed:
|
||||
state_tensor = state_tensor.to(param.device)
|
||||
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
|
||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||
|
||||
return collected_states
|
||||
|
||||
@@ -658,6 +705,14 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
ret_val = torch.zeros(
|
||||
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
|
||||
)
|
||||
|
||||
if is_dtensor:
|
||||
value = torch.reshape(value, global_shape)
|
||||
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
|
||||
elif is_customized_distributed:
|
||||
value = torch.reshape(value, global_shape)
|
||||
value = distribute_tensor_with_customization(value, real_param.shard_fn, real_param.gather_fn)
|
||||
|
||||
ret_val.copy_(value.flatten()[state_start:state_end])
|
||||
return ret_val
|
||||
|
||||
@@ -668,6 +723,15 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
|
||||
# Copy states assigned to param (and cast tensors to appropriate types).
|
||||
updated_states = dict()
|
||||
|
||||
# get tensor parallelism information
|
||||
real_param = self.id_to_real_params[param_id]
|
||||
is_dtensor = is_distributed_tensor(real_param)
|
||||
is_customized_distributed = is_customized_distributed_tensor(real_param)
|
||||
shard_spec = get_sharding_spec(real_param) if is_dtensor else None
|
||||
device_mesh = get_device_mesh(real_param) if is_dtensor else None
|
||||
global_shape = self.optimizer_params_info["id2shape"][param_id]
|
||||
|
||||
for k, v in saved_states.items():
|
||||
updated_states[k] = cast(fake_param, state_range, v, k)
|
||||
del v # clean loaded states
|
||||
|
Reference in New Issue
Block a user