mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 04:03:58 +00:00
[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384) * [sequence parallel] add sequence parallel linear col/row support (#4336) * add sequence parallel linear col/row support * add annotation * add annotation * add support for gpt2 fused qkv linear layer * support sequence parallel in GPT2 * add docstring and note * add requirments * remove unused flash-attb * modify flash attn test * modify flash attn setting * modify flash attn code * add assert before divide, rename forward function * [shardformer/test] fix gpt2 test with seq-parallel * [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401) * overlap gather input / grad computing during col backward * modify test for overlap * simplify code * fix code and modify cuda stream synchronize * [shardformer/sequence parallel] polish code
This commit is contained in:
@@ -11,17 +11,12 @@ from torch.nn import Module
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ..layer.parallel_module import ParallelModule
|
||||
from ..shard.shard_config import ShardConfig
|
||||
|
||||
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
|
||||
|
||||
|
||||
class ParallelModule():
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubModuleReplacementDescription:
|
||||
r"""
|
||||
@@ -231,3 +226,22 @@ class Policy(ABC):
|
||||
end_idx = num_layers_per_stage_accumulated[stage + 1]
|
||||
|
||||
return [start_idx, end_idx]
|
||||
|
||||
def append_seq_parallel_to_policy(
|
||||
self,
|
||||
suffix_list: List[str],
|
||||
module_policy_description: ModulePolicyDescription,
|
||||
):
|
||||
r"""
|
||||
Append the sequence parallel policy to the policy for the given key.
|
||||
|
||||
Args:
|
||||
suffix_list (List[str]): the suffix list of the module to be parallelized
|
||||
policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
|
||||
"""
|
||||
|
||||
for sub_description in module_policy_description.sub_module_replacement:
|
||||
if (sub_description.suffix in suffix_list):
|
||||
if sub_description.kwargs is None:
|
||||
sub_description.kwargs = {}
|
||||
sub_description.kwargs["seq_parallel"] = True
|
||||
|
Reference in New Issue
Block a user