Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallelism
* fix gather output
* fix resize embedding
* revert
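Background for the fix: with tensor parallelism the language-model head is sharded along the vocabulary dimension, so each rank produces logits of shape (batch, seq_len, vocab_size // tp_size). Keeping parallel_output=True leaves the logits sharded (useful for a vocab-parallel loss); turning it off means gathering them back along the last dimension. A minimal illustrative sketch of such a gather with plain torch.distributed — the function name and the default process group are assumptions, not ColossalAI's API:

import torch
import torch.distributed as dist

def gather_logits(logits: torch.Tensor, group=None) -> torch.Tensor:
    # All-gather vocab-sharded logits along the last dimension.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return logits
    shards = [torch.empty_like(logits) for _ in range(world_size)]
    dist.all_gather(shards, logits.contiguous(), group=group)
    # NOTE: dist.all_gather is not autograd-aware; a real implementation
    # wraps the gather in a torch.autograd.Function to keep gradients.
    return torch.cat(shards, dim=-1)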
@@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer):
     if optim is None:
         return {}
-    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+    param_info = {
+        "param_groups": [],
+        "param2id": {},
+        "id2param": {},
+        "param2shape": {},
+    }
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}
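The hunk above only reformats the head of get_param_info; for context, a hedged sketch of how the four maps are plausibly filled further down the function (the loop body past packed_group is my reconstruction, not the verbatim source):

from torch.optim import Optimizer

def get_param_info_sketch(optim: Optimizer) -> dict:
    if optim is None:
        return {}
    param_info = {
        "param_groups": [],
        "param2id": {},
        "id2param": {},
        "param2shape": {},
    }
    start_index = 0
    for group in optim.param_groups:
        # Copy every hyperparameter except the parameter tensors themselves.
        packed_group = {k: v for k, v in group.items() if k != "params"}
        packed_group["params"] = []
        for param_id, param in enumerate(group["params"], start_index):
            # Give each parameter a stable integer id across all groups.
            packed_group["params"].append(param_id)
            param_info["param2id"][id(param)] = param_id
            param_info["id2param"][param_id] = id(param)
            param_info["param2shape"][id(param)] = tuple(param.shape)
        param_info["param_groups"].append(packed_group)
        start_index += len(group["params"])
    return param_info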
@@ -899,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
             Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
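As the docstring just above requires, a pipeline run needs one of the two microbatch settings; a hedged sketch (sizes are arbitrary examples, and the remaining constructor defaults are assumed to suffice):

from colossalai.booster.plugin import HybridParallelPlugin

# Either fix the number of microbatches per step...
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4)
# ...or fix the per-microbatch size and let the count be derived from the batch.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, microbatch_size=2)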
@@ -939,6 +945,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -1035,6 +1042,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            parallel_output=parallel_output,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
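A hedged usage sketch tying the change together: the import path and the tp_size/pp_size arguments follow ColossalAI's documented Booster API, while parallel_output comes straight from this diff.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# parallel_output=False asks Shardformer to gather the TP-sharded model
# output, so every rank sees the full (batch, seq_len, vocab_size) logits;
# the default True keeps the output split across the tensor-parallel group.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, parallel_output=False)
booster = Booster(plugin=plugin)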