Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)
* [sequence parallel] add sequence parallel linear col/row support (#4336)
* add sequence parallel linear col/row support
* add annotation
* add annotation
* add support for gpt2 fused qkv linear layer
* support sequence parallel in GPT2
* add docstring and note
* add requirements
* remove unused flash-attn
* modify flash attn test
* modify flash attn setting
* modify flash attn code
* add assert before divide, rename forward function
* [shardformer/test] fix gpt2 test with seq-parallel
* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)
* overlap input gather / grad computation during col backward
* modify test for overlap
* simplify code
* fix code and modify cuda stream synchronize
* [shardformer/sequence parallel] polish code
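
For orientation, the sketch below illustrates the idea the cherry-picked commits describe; it is not ColossalAI's actual layer implementation, and all names (`SeqParallelColLinearSketch`, `partial_input`, and so on) are illustrative. A sequence-parallel column linear all-gathers the sequence-sharded input in the forward pass; in the backward pass the same all-gather is issued asynchronously so it overlaps with the input-gradient matmul, and a reduce-scatter returns the input gradient to each rank's sequence shard.

```python
import torch
import torch.distributed as dist


class SeqParallelColLinearSketch(torch.autograd.Function):
    """Sketch of a sequence-parallel column linear (illustrative, not ColossalAI's code)."""

    @staticmethod
    def forward(ctx, partial_input, weight, process_group):
        # partial_input: [batch, seq_len // world_size, hidden], sharded along the sequence dim.
        # weight: [out_features // world_size, hidden], sharded along the output dim.
        ctx.save_for_backward(partial_input, weight)
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        # Gather the full sequence so the column-sharded weight can be applied locally.
        gathered = [torch.empty_like(partial_input) for _ in range(world_size)]
        dist.all_gather(gathered, partial_input.contiguous(), group=process_group)
        full_input = torch.cat(gathered, dim=1)           # [batch, seq, hidden]
        return full_input @ weight.t()                    # [batch, seq, out // world_size]

    @staticmethod
    def backward(ctx, grad_output):
        partial_input, weight = ctx.saved_tensors
        group = ctx.process_group
        world_size = dist.get_world_size(group)
        # Re-issue the all-gather of the saved input asynchronously ...
        gathered = [torch.empty_like(partial_input) for _ in range(world_size)]
        handle = dist.all_gather(gathered, partial_input.contiguous(), group=group, async_op=True)
        # ... and overlap it with the input gradient, which does not need the gathered input.
        grad_full_input = grad_output @ weight            # [batch, seq, hidden]
        handle.wait()
        full_input = torch.cat(gathered, dim=1)
        grad_weight = grad_output.flatten(0, 1).t() @ full_input.flatten(0, 1)
        # Sum the partial input gradients across ranks and scatter them back to
        # each rank's sequence shard in a single reduce-scatter.
        chunks = [c.contiguous() for c in grad_full_input.chunk(world_size, dim=1)]
        grad_partial = torch.empty_like(partial_input)
        dist.reduce_scatter(grad_partial, chunks, group=group)
        return grad_partial, grad_weight, None
```

Applying it would look like `SeqParallelColLinearSketch.apply(x_shard, weight, group)`; the row-parallel counterpart typically mirrors this with the communication directions swapped (reduce-scatter along the sequence dimension in forward, all-gather in backward).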
@@ -1,4 +1,5 @@
 import copy
+import math
 from contextlib import nullcontext
 from typing import Any, Callable, Dict, List, Optional
 
@@ -25,6 +26,7 @@ def build_model(model_fn,
                 enable_tensor_parallelism=True,
                 enable_flash_attention=False,
                 enable_jit_fused=False,
+                enable_sequence_parallelism=False,
                 use_lazy_init: bool = False):
     # create new model
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -38,7 +40,8 @@ def build_model(model_fn,
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism,
                                enable_flash_attention=enable_flash_attention,
-                               enable_jit_fused=enable_jit_fused)
+                               enable_jit_fused=enable_jit_fused,
+                               enable_sequence_parallelism=enable_sequence_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
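The `ShardConfig` change in the hunk above is the switch the rest of the test exercises. As a rough usage sketch (assuming a distributed environment with a tensor-parallel process group has already been launched, e.g. via `colossalai.launch_from_torch`; the small GPT2 stands in for any model ShardFormer has a policy for):

```python
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.shardformer import ShardConfig, ShardFormer

# Shard a small GPT2 with tensor parallelism plus the newly added sequence
# parallelism; optimize() returns the sharded module and any shared parameters.
model = GPT2LMHeadModel(GPT2Config())
shard_config = ShardConfig(enable_tensor_parallelism=True,
                           enable_sequence_parallelism=True)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)
```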
@@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
         return loss
 
     data = data_gen_fn()
+
+    if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
+        seq_len = data['input_ids'].shape[1]
+        lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
+        times = lcm // seq_len
+        input_shape = data['input_ids'].shape
+        for k, v in data.items():
+            if v.shape == input_shape:
+                data[k] = v.repeat(1, times)
+
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in data.items():
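A note on the last hunk: with sequence parallelism the activations are split along the sequence dimension across tensor-parallel ranks, so the test repeats every tensor that shares the input shape until the sequence length reaches the least common multiple of `seq_len` and `tp_size` and therefore divides evenly. The numbers below are illustrative:

```python
import math

# e.g. tp_size = 4 ranks and an original seq_len of 6 tokens (illustrative values):
# lcm(4, 6) = 12, so the batch is repeated 12 // 6 = 2 times along the sequence
# dimension and each rank then holds exactly 12 // 4 = 3 tokens.
tp_size, seq_len = 4, 6
lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)
times = lcm // seq_len
assert lcm == 12 and times == 2
assert (seq_len * times) % tp_size == 0
```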