[shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when useing sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megtron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for  sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several model

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
Zhongkai Zhao
2024-04-03 17:15:47 +08:00
committed by GitHub
parent 7e0ec5a85c
commit 8e412a548e
33 changed files with 1630 additions and 256 deletions

View File

@@ -1,5 +1,4 @@
import copy
import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
@@ -123,7 +122,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
org_model = org_model.cuda()
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
@@ -162,24 +160,22 @@ def run_forward_backward_with_hybrid_plugin(
data = data_gen_fn()
if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data["input_ids"].shape[-1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
input_shape = data["input_ids"].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))
shard_test_data = {}
for k, v in data.items():
shard_test_data[k] = data[k].clone()
unshard_test_data = {}
for k, v in data.items():
unshard_test_data[k] = data[k].clone()
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
for k, v in shard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to("cuda").repeat(*new_shape)
shard_test_data[k] = v.to("cuda").repeat(*new_shape)
data_iter = iter([data])
data_iter = iter([shard_test_data])
sharded_output = booster.execute_pipeline(
data_iter,
sharded_model,
@@ -189,17 +185,22 @@ def run_forward_backward_with_hybrid_plugin(
return_outputs=True,
)
sharded_loss = sharded_output["loss"]
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)
else:
shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
sharded_output = sharded_model(**shard_test_data)
sharded_loss = criterion(sharded_output)
sharded_optimizer.backward(sharded_loss)
org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)
if booster.plugin.stage_manager is not None:
for k, v in unshard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
org_output = org_model(**unshard_test_data)
org_loss = criterion(org_output)
org_loss.backward()
@@ -212,7 +213,6 @@ def check_output_hidden_state(
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3,
dim: int = 0,
):
org_hidden_state = org_output.last_hidden_state

View File

@@ -100,6 +100,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
@@ -154,7 +176,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
def run_bert_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

View File

@@ -99,6 +99,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,

View File

@@ -135,6 +135,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,

View File

@@ -131,6 +131,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,

View File

@@ -2,6 +2,8 @@ import os
import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.logging import disable_existing_loggers
@@ -46,6 +48,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
norm_layer_for_check = ["layers[0].input_layernorm", "layers[0].post_attention_layernorm"]
# During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled
if stage_manager is None:
norm_layer_for_check.append("norm")
# Check the grad when using ZeRO-1 and ZeRO-2
if (
booster.plugin.zero_stage in [1, 2]
and booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
@@ -60,8 +82,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_grads = get_grad_tensors_for_check(
llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
norm_layer_grads = get_grad_tensors_for_check(
llama_model,
shard_llama_model,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()
@@ -98,6 +131,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,