[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)

* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)

* [sequence parallel] add sequence parallel linear col/row support (#4336)

* add sequence parallel linear col/row support

* add annotation

* add annotation

* add support for gpt2 fused qkv linear layer

* support sequence parallel in GPT2

* add docstring and note

* add requirements

* remove unused flash-attn

* modify flash attn test

* modify flash attn setting

* modify flash attn code

* add assert before divide, rename forward function

* [shardformer/test] fix gpt2 test with seq-parallel
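The col/row linear support above follows the usual sequence-parallel pattern: the input arrives sharded along the sequence dimension, a column-parallel linear all-gathers it along that dimension in the forward pass and reduce-scatters the gradient in the backward pass (row-parallel layers do the reverse). Below is a minimal sketch of that gather-forward / reduce-scatter-backward autograd function; the class name and arguments are illustrative only, not the actual colossalai.shardformer API.

import torch
import torch.distributed as dist

class _GatherForwardReduceScatterBackward(torch.autograd.Function):
    """All-gather the sequence-sharded input for the column-parallel matmul;
    reduce-scatter the gradient back to each rank's sequence chunk."""

    @staticmethod
    def forward(ctx, local_input, process_group, seq_dim=1):
        ctx.process_group = process_group
        ctx.seq_dim = seq_dim
        world_size = dist.get_world_size(process_group)
        chunks = [torch.empty_like(local_input) for _ in range(world_size)]
        dist.all_gather(chunks, local_input.contiguous(), group=process_group)
        return torch.cat(chunks, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        # Each rank keeps the gradient slice of its own sequence chunk,
        # summed across ranks via reduce-scatter.
        grad_chunks = [g.contiguous() for g in grad_output.chunk(world_size, dim=ctx.seq_dim)]
        grad_input = torch.empty_like(grad_chunks[0])
        dist.reduce_scatter(grad_input, grad_chunks, group=ctx.process_group)
        return grad_input, None, None

A column-parallel linear would apply this before its matmul (e.g. F.linear on the gathered input and the local weight shard); the row-parallel counterpart mirrors it with reduce-scatter in forward and all-gather in backward.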

* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)

* overlap gather input / grad computing during col backward

* modify test for overlap

* simplify code

* fix code and modify cuda stream synchronize
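The overlap described above exploits the fact that, in the col-linear backward, the all-gathered input is needed only for the weight gradient, while the input gradient depends only on grad_output and the weight. Launching the all-gather asynchronously lets the communication run while the input-gradient matmul executes. A rough sketch under Megatron-style shapes; the function name and signature are hypothetical, not the commit's actual code:

import torch
import torch.distributed as dist

def col_linear_backward_with_overlap(grad_output, local_input, weight, process_group, seq_dim=1):
    world_size = dist.get_world_size(process_group)

    # Start re-gathering the sequence-sharded input asynchronously;
    # it is only consumed later by the weight-gradient matmul.
    gather_list = [torch.empty_like(local_input) for _ in range(world_size)]
    handle = dist.all_gather(gather_list, local_input.contiguous(),
                             group=process_group, async_op=True)

    # Overlap: compute the input gradient while the all-gather is in flight.
    grad_gathered_input = grad_output.matmul(weight)   # weight: [out_partition, in_features]

    # Wait for the communication, then compute the weight gradient
    # with the fully gathered input.
    handle.wait()
    gathered_input = torch.cat(gather_list, dim=seq_dim)
    grad_weight = grad_output.flatten(0, 1).t().matmul(gathered_input.flatten(0, 1))

    return grad_gathered_input, grad_weight

The "modify cuda stream synchronize" bullet above points at the extra care the real implementation takes to make sure the gathered buffer is valid before the weight-gradient matmul reads it.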

* [shardformer/sequence parallel] polish code
Bin Jia, committed by GitHub on 2023-08-16 15:41:20 +08:00
parent d20dceb9a3
commit 424629fea0
12 changed files with 655 additions and 65 deletions


@@ -1,4 +1,5 @@
 import copy
+import math
 from contextlib import nullcontext
 from typing import Any, Callable, Dict, List, Optional
@@ -25,6 +26,7 @@ def build_model(model_fn,
                  enable_tensor_parallelism=True,
                  enable_flash_attention=False,
                  enable_jit_fused=False,
+                 enable_sequence_parallelism=False,
                  use_lazy_init: bool = False):
     # create new model
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -38,7 +40,8 @@ def build_model(model_fn,
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism,
                                enable_flash_attention=enable_flash_attention,
-                               enable_jit_fused=enable_jit_fused)
+                               enable_jit_fused=enable_jit_fused,
+                               enable_sequence_parallelism=enable_sequence_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
@@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
         return loss

     data = data_gen_fn()
+
+    if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
+        seq_len = data['input_ids'].shape[1]
+        lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
+        times = lcm // seq_len
+        input_shape = data['input_ids'].shape
+        for k, v in data.items():
+            if v.shape == input_shape:
+                data[k] = v.repeat(1, times)
+
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in data.items():
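The last hunk repeats the generated inputs along the sequence dimension so that the total sequence length divides evenly by the tensor-parallel size, since sequence parallelism splits that dimension across tp_size ranks. A quick check of the arithmetic with illustrative numbers:

import math

tp_size, seq_len = 4, 6                                   # hypothetical test values
lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)     # lcm(4, 6) = 12
times = lcm // seq_len                                    # repeat the batch 2x along dim 1
assert (seq_len * times) % tp_size == 0                   # 12 tokens split evenly over 4 ranks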