mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
@@ -123,7 +122,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
|
||||
sharded_model = copy.deepcopy(org_model)
|
||||
if use_lazy_init:
|
||||
ctx.materialize(org_model)
|
||||
|
||||
org_model = org_model.cuda()
|
||||
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
|
||||
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
|
||||
@@ -162,24 +160,22 @@ def run_forward_backward_with_hybrid_plugin(
|
||||
|
||||
data = data_gen_fn()
|
||||
|
||||
if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
|
||||
seq_len = data["input_ids"].shape[-1]
|
||||
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
|
||||
times = lcm // seq_len
|
||||
input_shape = data["input_ids"].shape
|
||||
for k, v in data.items():
|
||||
if v.shape == input_shape:
|
||||
data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))
|
||||
shard_test_data = {}
|
||||
for k, v in data.items():
|
||||
shard_test_data[k] = data[k].clone()
|
||||
unshard_test_data = {}
|
||||
for k, v in data.items():
|
||||
unshard_test_data[k] = data[k].clone()
|
||||
|
||||
sharded_model.train()
|
||||
if booster.plugin.stage_manager is not None:
|
||||
for k, v in data.items():
|
||||
for k, v in shard_test_data.items():
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = 4
|
||||
data[k] = v.to("cuda").repeat(*new_shape)
|
||||
shard_test_data[k] = v.to("cuda").repeat(*new_shape)
|
||||
|
||||
data_iter = iter([data])
|
||||
data_iter = iter([shard_test_data])
|
||||
sharded_output = booster.execute_pipeline(
|
||||
data_iter,
|
||||
sharded_model,
|
||||
@@ -189,17 +185,22 @@ def run_forward_backward_with_hybrid_plugin(
|
||||
return_outputs=True,
|
||||
)
|
||||
sharded_loss = sharded_output["loss"]
|
||||
else:
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
sharded_output = sharded_model(**data)
|
||||
|
||||
else:
|
||||
shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
|
||||
sharded_output = sharded_model(**shard_test_data)
|
||||
sharded_loss = criterion(sharded_output)
|
||||
sharded_optimizer.backward(sharded_loss)
|
||||
|
||||
org_model.train()
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
org_output = org_model(**data)
|
||||
|
||||
if booster.plugin.stage_manager is not None:
|
||||
for k, v in unshard_test_data.items():
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = 4
|
||||
unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
|
||||
unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
|
||||
org_output = org_model(**unshard_test_data)
|
||||
org_loss = criterion(org_output)
|
||||
org_loss.backward()
|
||||
|
||||
@@ -212,7 +213,6 @@ def check_output_hidden_state(
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-3,
|
||||
dim: int = 0,
|
||||
):
|
||||
org_hidden_state = org_output.last_hidden_state
|
||||
|
||||
|
Reference in New Issue
Block a user