Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)
* support gpt2 seq parallel with pp/dp/tp
* fix a bug when waiting for stream done
* delete unused gpt2_seq file
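The hunks below all touch the GPT2 shardformer policy: the sequence-parallel forward helper moves into the main gpt2 modeling module, its registration moves under the existing enable_sequence_parallelism branch, and the pipeline forward gains a shard_config argument. The policies map a module class to a method_replacement dict whose entries are rebound onto every matching submodule. A minimal self-contained sketch of that rebinding pattern follows; apply_method_replacement and traced_forward are illustrative names, not ColossalAI API:

from functools import partial
from types import MethodType

import torch
import torch.nn as nn


def apply_method_replacement(model: nn.Module, target_cls, replacement: dict):
    """Rebind methods on every submodule matching target_cls."""
    for module in model.modules():
        if isinstance(module, target_cls):
            for name, fn in replacement.items():
                # MethodType binds the module as `self`; the instance
                # attribute shadows the class-level forward, so
                # Module.__call__ picks up the replacement.
                setattr(module, name, MethodType(fn, module))


def traced_forward(self, x, *, tag):
    print(f"[{tag}] {type(self).__name__} forward")
    return nn.Linear.forward(self, x)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
apply_method_replacement(model, nn.Linear, {"forward": partial(traced_forward, tag="sp")})
model(torch.randn(2, 4))  # both Linear layers now route through traced_forward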
@@ -6,8 +6,7 @@ from torch import Tensor, nn
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
-from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
-from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -50,8 +49,6 @@ class GPT2Policy(Policy):
                     target_module=col_nn.DropoutForParallelInput,
                 ),
             ])
-        if self.shard_config.enable_sequence_parallelism:
-            policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
 
         policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
             "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -126,6 +123,7 @@ class GPT2Policy(Policy):
             })
 
         if self.shard_config.enable_sequence_parallelism:
+            policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
             suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
             self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
 
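This hunk relocates the GPT2Model forward replacement so it only fires inside the enable_sequence_parallelism branch, next to the suffix list of linear layers that get sequence-parallel wrappers. The body of the replaced forward is not shown in this diff; conceptually, sequence parallelism shards activations along the sequence dimension across the parallel group. A minimal sketch of that split/gather idea, assuming an initialized torch.distributed process group (helper names are illustrative, not ColossalAI's implementation):

import torch
import torch.distributed as dist


def split_sequence(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Keep only this rank's shard of the sequence dimension.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert x.size(dim) % world_size == 0, "sequence length must divide evenly"
    return x.chunk(world_size, dim=dim)[rank].contiguous()


def gather_sequence(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Reassemble the full sequence from every rank's shard.
    world_size = dist.get_world_size()
    chunks = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(chunks, x)
    return torch.cat(chunks, dim=dim)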
@@ -169,7 +167,13 @@ class GPT2Policy(Policy):
 
         layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        method_replacement = {
+            'forward':
+                partial(new_forward,
+                        stage_manager=stage_manager,
+                        stage_index=stage_index,
+                        shard_config=self.shard_config)
+        }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
 
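The final hunk threads shard_config into the pipeline-stage forward via functools.partial, so each stage's forward can see the sequence-parallel settings without any change at its call sites. The two helpers it invokes assign contiguous runs of transformer blocks to pipeline stages; below is a plausible standalone reimplementation for orientation only (ColossalAI's actual Policy.distribute_layers may spread the remainder differently):

def distribute_layers(num_layers: int, num_stages: int) -> list:
    # Split layers as evenly as possible; later stages absorb the remainder.
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: list, stage: int) -> tuple:
    # Convert per-stage counts into this stage's (start, end) layer slice.
    start = sum(layers_per_stage[:stage])
    return (start, start + layers_per_stage[stage])


assert distribute_layers(12, 5) == [2, 2, 2, 3, 3]
assert get_stage_index([2, 2, 2, 3, 3], 3) == (6, 9)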