mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redunant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686
.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
This commit is contained in:
@@ -11,7 +11,11 @@ from transformers.utils import logging
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
|
||||
|
||||
def get_flash_core_attention_forward():
|
||||
@@ -203,6 +207,13 @@ class ChatGLMPipelineForwards:
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
)
|
||||
for idx in range(start_idx, end_idx):
|
||||
layer = self.encoder._get_layer(idx)
|
||||
if output_hidden_states:
|
||||
@@ -235,6 +246,13 @@ class ChatGLMPipelineForwards:
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
if stage_manager.is_last_stage():
|
||||
@@ -329,7 +347,9 @@ class ChatGLMPipelineForwards:
|
||||
return transformer_outputs
|
||||
|
||||
|
||||
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -381,13 +401,27 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
||||
|
||||
if sp_mode in ["all_to_all"] and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
# Run encoder.
|
||||
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
if sp_mode in ["split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds,
|
||||
dim=0,
|
||||
process_group=sp_group,
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds,
|
||||
dim=0,
|
||||
process_group=sp_group,
|
||||
grad_scale=1 / sp_size,
|
||||
)
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
full_attention_mask,
|
||||
@@ -397,11 +431,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
if sp_mode in ["split_gather"]:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=sp_group,
|
||||
grad_scale=sp_size,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
@@ -423,3 +465,158 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
|
||||
from .chatglm2_6b.modeling_chatglm import apply_rotary_pos_emb, split_tensor_along_last_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_cache=None,
|
||||
use_cache=True,
|
||||
):
|
||||
if sp_mode is not None:
|
||||
assert sp_mode in ["all_to_all", "split_gather"], "Invalid sp_mode"
|
||||
assert (sp_size is not None) and (
|
||||
sp_group is not None
|
||||
), "Must specify sp_size and sp_group for sequence parallel"
|
||||
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
if self.multi_query_attention:
|
||||
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
||||
[
|
||||
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query_layer = query_layer.view(
|
||||
query_layer.size()[:-1]
|
||||
+ (
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
key_layer = key_layer.view(
|
||||
key_layer.size()[:-1]
|
||||
+ (
|
||||
self.num_multi_query_groups_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
value_layer = value_layer.view(
|
||||
value_layer.size()[:-1]
|
||||
+ (
|
||||
self.num_multi_query_groups_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
)
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
||||
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
sq, bs, _, _ = value_layer.size()
|
||||
|
||||
query_layer = query_layer.reshape(sq, bs, -1)
|
||||
key_layer = key_layer.reshape(sq, bs, -1)
|
||||
value_layer = value_layer.reshape(sq, bs, -1)
|
||||
|
||||
query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
|
||||
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
|
||||
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
|
||||
|
||||
query_layer = query_layer.view(
|
||||
sq * sp_size,
|
||||
bs,
|
||||
self.num_attention_heads_per_partition // sp_size,
|
||||
self.hidden_size_per_attention_head,
|
||||
).contiguous()
|
||||
|
||||
key_layer = key_layer.view(
|
||||
sq * sp_size,
|
||||
bs,
|
||||
self.num_attention_heads_per_partition // sp_size,
|
||||
self.hidden_size_per_attention_head,
|
||||
).contiguous()
|
||||
|
||||
value_layer = value_layer.view(
|
||||
sq * sp_size,
|
||||
bs,
|
||||
self.num_attention_heads_per_partition // sp_size,
|
||||
self.hidden_size_per_attention_head,
|
||||
).contiguous()
|
||||
|
||||
# apply relative positional encoding (rotary embedding)
|
||||
if rotary_pos_emb is not None:
|
||||
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
||||
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
||||
|
||||
# adjust key and value for inference
|
||||
if kv_cache is not None:
|
||||
cache_k, cache_v = kv_cache
|
||||
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
||||
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
||||
if use_cache:
|
||||
kv_cache = (key_layer, value_layer)
|
||||
else:
|
||||
kv_cache = None
|
||||
|
||||
if self.multi_query_attention:
|
||||
key_layer = key_layer.unsqueeze(-2)
|
||||
key_layer = key_layer.expand(
|
||||
-1,
|
||||
-1,
|
||||
-1,
|
||||
self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
|
||||
-1,
|
||||
)
|
||||
key_layer = key_layer.contiguous().view(
|
||||
key_layer.size()[:2]
|
||||
+ (
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
value_layer = value_layer.unsqueeze(-2)
|
||||
value_layer = value_layer.expand(
|
||||
-1,
|
||||
-1,
|
||||
-1,
|
||||
self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
|
||||
-1,
|
||||
)
|
||||
value_layer = value_layer.contiguous().view(
|
||||
value_layer.size()[:2]
|
||||
+ (
|
||||
self.num_attention_heads_per_partition // sp_size,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
|
||||
# ==================================
|
||||
# core attention computation
|
||||
# ==================================
|
||||
|
||||
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
||||
if sp_mode == "all_to_all":
|
||||
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
|
||||
|
||||
# =================
|
||||
# Output. [sq, b, h]
|
||||
# =================
|
||||
output = self.dense(context_layer)
|
||||
|
||||
return output, kv_cache
|
||||
|
||||
return forward
|
||||
|
Reference in New Issue
Block a user