[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)

* support gpt2 seq parallel with pp/dp/tp

* fix a bug when waiting for stream done (see the synchronization sketch below)

* delete unused gpt2_seq file
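The stream fix in the second bullet is a CUDA stream-synchronization issue. A minimal sketch of the pattern involved, with hypothetical names (the actual fix lives in the sequence-parallel layers, which are not part of this diff): work launched on a side communication stream must be waited on from the consuming stream before its output buffer is read.

import torch

comm_stream = torch.cuda.Stream()  # side stream for async communication

def gather_then_consume(input_, out_buf):
    # The side stream must first wait for pending writes to `input_`
    # issued on the current (compute) stream.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        out_buf.copy_(input_)  # stands in for an async all-gather

    # Independent compute may overlap here on the default stream.

    # Before `out_buf` is consumed, the compute stream must wait for the
    # communication stream to finish: the "waiting for stream done" step.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return out_buf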
Author: Bin Jia
Date: 2023-08-18 11:21:53 +08:00
Committed by: GitHub
Parent: a78daf6180
Commit: 7c8be77081
6 changed files with 268 additions and 240 deletions

colossalai/shardformer/policies/gpt2.py

@@ -6,8 +6,7 @@ from torch import Tensor, nn
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
-from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
-from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
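For context on the import consolidation above: gpt2_sequence_parallel_forward_fn is a factory that closes over the shard config and returns a replacement forward for GPT2Model (it is installed via method_replacement in the hunks below). A self-contained toy sketch of that closure-and-replace pattern, with invented Toy* names; the real forward reimplements GPT2Model.forward with sequence-dimension splitting:

from types import MethodType

import torch
import torch.nn as nn

class ToyShardConfig:
    tensor_parallel_size = 2  # the sequence dim would be split this many ways

def toy_sequence_parallel_forward_fn(shard_config):
    sp = shard_config.tensor_parallel_size

    def forward(self, hidden_states):
        # Stand-in for "split the sequence dim, compute locally, gather":
        # a real rank keeps seq_len // sp tokens and all-gathers afterwards.
        local = hidden_states.chunk(sp, dim=1)[0]
        return self.proj(local)

    return forward

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

model = ToyModel()
# Install the closure as a bound method, as the policy's method_replacement does:
model.forward = MethodType(toy_sequence_parallel_forward_fn(ToyShardConfig()), model)
print(model.forward(torch.randn(1, 4, 8)).shape)  # torch.Size([1, 2, 8])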
@@ -50,8 +49,6 @@ class GPT2Policy(Policy):
                             target_module=col_nn.DropoutForParallelInput,
                         ),
                     ])
-        if self.shard_config.enable_sequence_parallelism:
-            policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
 
         policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
             "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -126,6 +123,7 @@ class GPT2Policy(Policy):
         })
 
         if self.shard_config.enable_sequence_parallelism:
+            policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
             suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
             self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
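Only the call site of append_seq_parallel_to_policy appears in this diff. A plausible sketch of the helper, with an assumed body and kwarg name: it flags the matching sub-module replacements (the attention and MLP linears listed in suffix_list) so the sharded layers are built in sequence-parallel mode.

def append_seq_parallel_to_policy(self, suffix_list, module_policy_description):
    # For every sub-module whose suffix matches (e.g. "attn.c_attn"), pass a
    # sequence-parallel flag through to the replacement layer's constructor kwargs.
    for description in module_policy_description.sub_module_replacement:
        if description.suffix in suffix_list:
            if description.kwargs is None:
                description.kwargs = {}
            description.kwargs["seq_parallel"] = True  # assumed kwarg name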
@@ -169,7 +167,13 @@ class GPT2Policy(Policy):
         layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        method_replacement = {
+            'forward':
+                partial(new_forward,
+                        stage_manager=stage_manager,
+                        stage_index=stage_index,
+                        shard_config=self.shard_config)
+        }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
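Finally, a small self-contained illustration (toy objects, invented names) of why this hunk reflows method_replacement: functools.partial pre-binds the pipeline context into the replacement forward, and sequence parallelism now requires shard_config to be bound alongside stage_manager and stage_index so the staged forward can split and gather the sequence dimension.

from functools import partial

def new_forward(self, hidden_states, stage_manager=None, stage_index=None, shard_config=None):
    # A staged forward would run only layers[stage_index[0]:stage_index[1]] and,
    # when shard_config enables it, apply sequence-parallel split/gather.
    sp = getattr(shard_config, "enable_sequence_parallelism", False)
    return f"stage_index={stage_index}, seq_parallel={sp}"

class ToyShardConfig:
    enable_sequence_parallelism = True

bound = partial(new_forward, stage_manager=object(), stage_index=(0, 6),
                shard_config=ToyShardConfig())
print(bound(object(), None))  # stage_index=(0, 6), seq_parallel=True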