mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-04-11 14:43:10 +00:00
* [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. 
https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions <github-actions@github.com> * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions <github-actions@github.com> * [nfc] fix some typo with colossalai/ docs/ etc. 
(#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai <xukai16@foxamil.com> * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- 
Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commitfbf3c09e67. * Revert "[inference] Async dynamic batching (#4894)" This reverts commitfced140250. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commitfced140250. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commitfced140250. 
* Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. * test * Revert "test" This reverts commit479900c139. 
* [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> Co-authored-by: Xu Kai <xukai16@foxamil.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> 
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
317 lines
13 KiB
Python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Callable, List, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import (
    is_distributed_tensor,
    shard_colwise,
    shard_rowwise,
    sharded_tensor_to_existing_param,
)

from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]


class Embedding1D(ParallelModule):
    r"""Embedding for 1D parallelism.

    The weight is sharded column-wise (along ``embedding_dim``) across the process group.

    Args:
        num_embeddings (int): number of embeddings.
        embedding_dim (int): dimension of embedding.
        padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
            therefore, the embedding vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”, defaults to None.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        device (:class:`torch.device`, optional): The device of parameters, defaults to None.
        process_group (:class:`torch.distributed.ProcessGroup`, optional): The process group over which
            the weight is sharded, defaults to None.
        gather_output (bool, optional): If True, gather the per-rank outputs along the last dimension
            in the forward pass, defaults to True.
        weight (:class:`torch.nn.Parameter`, optional): An existing weight to reuse; if None, a new weight
            is created and initialized, defaults to None.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to normal initializer.

    The ``args`` and ``kwargs`` used in :func:`torch.nn.functional.embedding` should contain:
    ::

        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
                    renormalized to have norm max_norm. Note: this will modify weight in-place.
        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
                    of frequency of the words in the mini-batch. Default False.
        sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.

    More details about ``args`` and ``kwargs`` can be found in
    `Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.

    For more details about ``initializer``, please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = None,
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        gather_output: bool = True,
        weight: Optional[nn.Parameter] = None,
        weight_initializer: Callable = init.normal_(),
        *args,
        **kwargs,
    ):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.process_group = process_group

        self.padding_idx = padding_idx
        self.embed_args = args
        self.embed_kwargs = kwargs
        self.gather_output = gather_output

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        # Parameters.
        if weight is None:
            factory_kwargs = {"device": device, "dtype": dtype}
            self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)
            self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_colwise(self.weight.data, process_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if weight is None:
            with self.randomizer.fork_rng(enable_cpu=True):
                self.reset_parameters(weight_initializer)

    @staticmethod
    def from_native_module(
        module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]] = None, *args, **kwargs
    ) -> "Embedding1D":
        r"""
        Build a 1D parallelized Embedding from a native nn.Embedding module.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        num_embedding = module.num_embeddings
        embedding_dim = module.embedding_dim
        padding_idx = module.padding_idx
        max_norm = module.max_norm
        norm_type = module.norm_type
        scale_grad_by_freq = module.scale_grad_by_freq
        sparse = module.sparse
        dtype = module.weight.dtype
        device = module.weight.device

        # sparse is not supported yet
        if sparse:
            raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.")

        embedding = Embedding1D(
            num_embeddings=num_embedding,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            dtype=dtype,
            device=device,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            weight=module.weight,
            *args,
            **kwargs,
        )

        return embedding

    def reset_parameters(self, weight_initializer) -> None:
        fan_in, fan_out = self.num_embeddings, self.embedding_dim
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input_: Tensor) -> Tensor:
        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
        if self.gather_output:
            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
            return output
        else:
            return output_parallel
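
# Illustrative sketch, not part of the ColossalAI source: a pure-Python toy of the
# column-wise sharding that Embedding1D relies on. It assumes two simulated ranks and
# nested lists instead of tensors. The helpers shard_columns, lookup and gather_last_dim
# are hypothetical stand-ins for shard_colwise, F.embedding and the forward pass of
# gather_forward_split_backward; the real operations may differ in detail.

```python
from typing import List

Matrix = List[List[float]]


def shard_columns(weight: Matrix, world_size: int) -> List[Matrix]:
    """Split every row of `weight` into `world_size` equal column blocks."""
    embedding_dim = len(weight[0])
    assert embedding_dim % world_size == 0, "embedding_dim must divide evenly"
    width = embedding_dim // world_size
    return [[row[r * width:(r + 1) * width] for row in weight] for r in range(world_size)]


def lookup(weight: Matrix, token_ids: List[int]) -> Matrix:
    """Plain embedding lookup: one (copied) row per token id."""
    return [list(weight[t]) for t in token_ids]


def gather_last_dim(partials: List[Matrix]) -> Matrix:
    """Concatenate the per-rank outputs along the last dimension."""
    return [sum((p[i] for p in partials), []) for i in range(len(partials[0]))]


# Toy table: 4 tokens, embedding_dim 4, sharded over 2 simulated ranks.
full_weight = [[float(10 * t + d) for d in range(4)] for t in range(4)]
shards = shard_columns(full_weight, world_size=2)
token_ids = [3, 0, 2]

# Each simulated rank looks up the same token ids in its own column slice.
partial_outputs = [lookup(shard, token_ids) for shard in shards]
# Gathering along the last dimension restores the full embedding vectors.
gathered = gather_last_dim(partial_outputs)
reference = lookup(full_weight, token_ids)
```

# In this toy, `gathered` matches `reference`; with gather_output=False each rank would
# instead keep only its own entry of `partial_outputs`.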


class VocabParallelEmbedding1D(ParallelModule):
    r"""Embedding parallelized in the vocabulary dimension.

    The weight is sharded row-wise (along ``num_embeddings``) across the process group.

    Args:
        num_embeddings (int): number of embeddings.
        embedding_dim (int): dimension of embedding.
        padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
            therefore, the embedding vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”, defaults to None.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        device (:class:`torch.device`, optional): The device of parameters, defaults to None.
        process_group (:class:`torch.distributed.ProcessGroup`, optional): The process group over which
            the vocabulary is partitioned, defaults to None.
        weight (:class:`torch.nn.Parameter`, optional): An existing weight to reuse; if None, a new weight
            is created and initialized, defaults to None.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to normal initializer.

    The ``args`` and ``kwargs`` used in :func:`torch.nn.functional.embedding` should contain:
    ::

        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
                    renormalized to have norm max_norm. Note: this will modify weight in-place.
        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
                    of frequency of the words in the mini-batch. Default False.
        sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.

    More details about ``args`` and ``kwargs`` can be found in
    `Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.

    For more details about ``initializer``, please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = None,
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        weight: Optional[nn.Parameter] = None,
        weight_initializer: Callable = init.normal_(),
        *args,
        **kwargs,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.embed_args = args
        self.embed_kwargs = kwargs
        self.process_group = process_group

        tensor_parallel_size = dist.get_world_size(group=process_group)
        tensor_parallel_rank = dist.get_rank(group=process_group)

        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
        self.num_embeddings = self.num_embeddings_per_partition
        self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

        # padding index
        self.padding_idx = self._select_padding_idx(padding_idx)

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        # parameter
        if weight is None:
            factory_kwargs = {"device": device, "dtype": dtype}
            self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)
            self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_rowwise(self.weight.data, process_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if weight is None:
            self.reset_parameters(weight_initializer)

    @staticmethod
    def from_native_module(
        module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native pytorch embedding module to a parallel module.
        """
        LazyInitContext.materialize(module)
        # get the origin attributes
        num_embeddings = module.num_embeddings
        embedding_dim = module.embedding_dim
        padding_idx = module.padding_idx
        device = module.weight.device

        # ensure only one process group is used
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        # create the parallel module
        vocab_embedding_1d = VocabParallelEmbedding1D(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            device=device,
            process_group=process_group,
            weight=module.weight,
            *args,
            **kwargs,
        )

        return vocab_embedding_1d

    def reset_parameters(self, weight_initializer) -> None:
        with self.randomizer.fork_rng(enable_cpu=True):
            fan_in, fan_out = self.num_embeddings, self.embedding_dim
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        # self.padding_idx was already converted to a shard-local index by
        # _select_padding_idx (it is None when the padding row lives on another rank),
        # so it must not be compared against or offset by vocab_start_index again.
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def _select_padding_idx(self, padding_idx: int):
        # select padding index according to the rank
        if padding_idx is None:
            return None
        elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index:
            return padding_idx - self.vocab_start_index
        else:
            return None

    def forward(self, input_: Tensor) -> Tensor:
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0

        output_parallel = F.embedding(
            masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
        )

        # Mask the output embedding.
        embedding_output = output_parallel.clone()
        embedding_output[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_forward(embedding_output, self.process_group)
        return output
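
# Illustrative sketch, not part of the ColossalAI source: a pure-Python toy of the
# mask-lookup-reduce scheme in VocabParallelEmbedding1D.forward. It assumes three
# simulated ranks and nested lists instead of tensors. The helpers shard_rows,
# local_forward and all_reduce_sum are hypothetical stand-ins for shard_rowwise, the
# masked F.embedding call and reduce_forward; the real operations may differ in detail.

```python
from typing import List

Matrix = List[List[float]]


def shard_rows(weight: Matrix, world_size: int) -> List[Matrix]:
    """Give each rank a contiguous block of vocabulary rows."""
    vocab_size = len(weight)
    assert vocab_size % world_size == 0, "vocabulary must divide evenly"
    per_rank = vocab_size // world_size
    return [weight[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def local_forward(shard: Matrix, rank: int, token_ids: List[int]) -> Matrix:
    """Look up ids owned by this rank; emit zero vectors for all other ids."""
    per_rank = len(shard)
    start, end = rank * per_rank, (rank + 1) * per_rank
    dim = len(shard[0])
    output = []
    for token in token_ids:
        if start <= token < end:
            # Shift the global id into the local row range of this shard.
            output.append(list(shard[token - start]))
        else:
            # Out-of-range ids are masked so they contribute nothing to the sum.
            output.append([0.0] * dim)
    return output


def all_reduce_sum(partials: List[Matrix]) -> Matrix:
    """Element-wise sum of the per-rank outputs (stand-in for an all-reduce)."""
    return [[sum(values) for values in zip(*rows)] for rows in zip(*partials)]


# Toy table: 6 tokens, embedding_dim 2, partitioned over 3 simulated ranks.
full_weight = [[float(10 * t + d) for d in range(2)] for t in range(6)]
shards = shard_rows(full_weight, world_size=3)
token_ids = [5, 0, 3]

partial_outputs = [local_forward(shard, rank, token_ids) for rank, shard in enumerate(shards)]
reduced = all_reduce_sum(partial_outputs)
reference = [full_weight[t] for t in token_ids]
```

# Exactly one simulated rank owns each token id, so in this toy the summed `reduced`
# output matches the unsharded `reference` lookup.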