mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[shardformer] refactor embedding resize (#5603)
* [branch rebase] rebase main to Feature/resize_embedding (#5554) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [CI] run pre-commit (#5577) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme * run pre-commit --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [rebase] rebase main to resize-embedding (#5581) * [release] grok-1 314b inference (#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference * [example] update Grok-1 inference (#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url * [hotfix] set return_outputs=False in examples and polish code (#5404) * fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory consumption * feat: add return_outputs warning * style: remove `return_outputs=False` as it is the default value * [release] grok-1 inference benchmark (#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [shardformer]Fix lm parallel. (#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix * [fix] fix grok-1 example typo (#5506) * [devops] fix example test ci (#5504) * Fix ColoTensorSpec for py11 (#5440) * fixed layout converter caching and updated tester * Empty-Commit * [shardformer] update colo attention to support custom mask (#5510) * [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests * [format] applied code formatting on changed files in pull request 5510 (#5517) Co-authored-by: github-actions <github-actions@github.com> * [shardformer] fix pipeline forward error if custom layer distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [Fix] Grok-1 use tokenizer from the same pretrained path (#5532) * [fix] use tokenizer from the same pretrained path * trust remote code * [ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com> * [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests * fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * [hotfix] quick fixes to make legacy tutorials runnable (#5559) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [fix] fix typo s/muiti-node /multi-node etc. (#5448) * [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548) * [devops] remove post commit ci (#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]enable padding vocabulary size. (#5489) * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * padding vocab * padding vocabe * fix * fix * fxi * test ci * fix fix fix fix * fix fix * fix * fix * Update hybrid_parallel_plugin.py fix fix fix * fix fix * fix fix * fix * resolve super init resolve super init resolve super init resolve super init * resolve comments * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * vocab checkpointio * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix fix fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * padding vocab * fix * fix fix * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * cherry-pick * revert moe modify * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix fix fix fix fix fix fix fix * resolve comments resolve comments resolve comments resolve comments resolve comments * ptensor ptensor resolve comments fix fix fix fix fix resolve comments resolve comments resolve comments resolve comments resolve comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rebase * fix rebase --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import (
|
||||
)
|
||||
|
||||
from ._operation import gather_forward_split_backward, reduce_forward
|
||||
from .parallel_module import ParallelModule
|
||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
|
||||
__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
|
||||
|
||||
|
||||
class Embedding1D(ParallelModule):
|
||||
@@ -161,7 +161,80 @@ class Embedding1D(ParallelModule):
|
||||
return output_parallel
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(ParallelModule):
|
||||
class PaddingEmbedding(PaddingParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.padding_idx = padding_idx
|
||||
if num_embeddings % make_vocab_size_divisible_by != 0:
|
||||
self.num_embeddings = (
|
||||
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
|
||||
)
|
||||
# create weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
super().__init__(self.num_embeddings, num_embeddings, weight)
|
||||
|
||||
if weight is None:
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.normal_(self.weight)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native pytorch embedding module to a parallel module.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the origin attributes
|
||||
num_embeddings = module.num_embeddings
|
||||
embedding_dim = module.embedding_dim
|
||||
padding_idx = module.padding_idx
|
||||
device = module.weight.device
|
||||
# create the parallel module
|
||||
padding_embedding = PaddingEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return padding_embedding
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(PaddingParallelModule):
|
||||
r"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Args:
|
||||
@@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule):
|
||||
process_group: ProcessGroup = None,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embed_args = args
|
||||
@@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule):
|
||||
tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
tensor_parallel_rank = dist.get_rank(group=process_group)
|
||||
|
||||
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
|
||||
self.num_embeddings = self.num_embeddings_per_partition
|
||||
# generate weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
# calculate new padding size
|
||||
multiple = make_vocab_size_divisible_by * tensor_parallel_size
|
||||
if num_embeddings % multiple != 0:
|
||||
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
|
||||
|
||||
# resize vocabulary size
|
||||
super().__init__(self.num_embeddings, num_embeddings, weight)
|
||||
|
||||
# deal with tensor parallelism
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
|
||||
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
@@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule):
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# parameter
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
@@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule):
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native pytorch embedding module to a parallel module.
|
||||
"""
|
||||
@@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule):
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
|
||||
output_parallel = F.embedding(
|
||||
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
|
||||
)
|
||||
|
||||
# Mask the output embedding.
|
||||
embedding_output = output_parallel.clone()
|
||||
embedding_output[input_mask, :] = 0.0
|
||||
|
Reference in New Issue
Block a user