[shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

* fix gather output

* fix resize embedding

* revert
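
For context, "padding vocab_size" here refers to rounding the vocabulary up so the embedding table divides evenly across tensor-parallel ranks. A minimal sketch of that rule (the helper name and parameters are ours, not the commit's):

```python
def pad_vocab_size(vocab_size: int, divisor: int, tp_size: int) -> int:
    # Round vocab_size up to a multiple of divisor * tp_size so each
    # tensor-parallel rank receives an equally sized embedding shard.
    multiple = divisor * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple

# e.g. GPT-2's 50257-token vocab padded for 2 TP ranks, divisor 64:
assert pad_vocab_size(50257, divisor=64, tp_size=2) == 50304
```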
Author: flybird11111 (committed by GitHub)
Date: 2024-03-18 15:55:11 +08:00
parent f2e8b9ef9f
commit 5e16bf7980
6 changed files with 32 additions and 13 deletions


@@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer):
     if optim is None:
         return {}
-    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+    param_info = {
+        "param_groups": [],
+        "param2id": {},
+        "id2param": {},
+        "param2shape": {},
+    }
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}
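
This hunk only reflows the dict literal; the loop body is mostly elided in the diff. A hedged reconstruction of what get_param_info plausibly records, inferred from the fields shown (the inner bookkeeping below is our guess, not the verbatim upstream code):

```python
from torch.optim import Optimizer

def get_param_info_sketch(optim: Optimizer):
    # Map each parameter to a stable integer id and remember its shape,
    # so sharded/reshaped parameters can be matched back to the optimizer.
    if optim is None:
        return {}
    param_info = {
        "param_groups": [],
        "param2id": {},
        "id2param": {},
        "param2shape": {},
    }
    start_index = 0
    for group in optim.param_groups:
        # Copy every group hyperparameter except the parameter list itself.
        packed_group = {k: v for k, v in group.items() if k != "params"}
        packed_group["params"] = []
        for param_id, param in enumerate(group["params"], start_index):
            packed_group["params"].append(param_id)
            param_info["param2id"][id(param)] = param_id
            param_info["id2param"][param_id] = id(param)
            param_info["param2shape"][id(param)] = param.shape
        param_info["param_groups"].append(packed_group)
        start_index += len(group["params"])
    return param_info
```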
@@ -899,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
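
The new flag controls whether the model's logits stay sharded along the vocabulary dimension or are gathered back to the full vocabulary on every tensor-parallel rank. A sketch of the collective that parallel_output=False implies (the function name and signature are illustrative, not Shardformer's API):

```python
import torch
import torch.distributed as dist

def gather_vocab_parallel_logits(logits: torch.Tensor, tp_group=None) -> torch.Tensor:
    # Each TP rank holds logits of shape (..., vocab_size / tp_size);
    # all_gather the shards and concatenate along the vocab (last) dim.
    world_size = dist.get_world_size(group=tp_group)
    if world_size == 1:
        return logits
    shards = [torch.empty_like(logits) for _ in range(world_size)]
    dist.all_gather(shards, logits.contiguous(), group=tp_group)
    return torch.cat(shards, dim=-1)
```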
@@ -939,6 +945,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -1035,6 +1042,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            parallel_output=parallel_output,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
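
Taken together, the change threads the flag from the plugin constructor into the shard config. A hedged usage sketch (the parallel sizes below are illustrative):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# With parallel_output=False, lm-head logits are gathered across the
# tensor-parallel group instead of staying sharded along the vocab dim,
# so downstream code sees full-vocabulary outputs on every rank.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    parallel_output=False,
)
booster = Booster(plugin=plugin)
```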