Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00

Commit: Merge branch 'main' into sync/npu
@@ -22,8 +22,8 @@ from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x


-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,
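For context, a minimal sketch of the mixin pattern introduced here, assuming AMPModelMixin's role is to give every wrapped model a uniform AMP hook; the hook name below is an assumption for illustration, not taken from this diff:

    # Illustrative stand-in for colossalai.interface.AMPModelMixin; assumed API.
    from torch import nn

    class AMPHookMixin:
        def update_master_params(self) -> None:
            """No-op by default; AMP-aware wrappers override this."""

    class Wrapper(nn.Module, AMPHookMixin):
        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module

    # The plugin/optimizer can now call the hook on any wrapper uniformly.
    Wrapper(nn.Linear(4, 4)).update_master_params()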
@@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper):
         Returns:
             None
         """
-
         if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
             if grads is not None:
                 # Synchronize provided gradient tensors across the tensor parallelism group.
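To illustrate what the sequence-parallel gradient sync above boils down to, a minimal sketch assuming the partial gradients are summed over the tensor-parallel group; the helper name is hypothetical, the real work is delegated to SeqParallelUtils:

    # Hypothetical helper: all-reduce partial grads across the TP group.
    from typing import List
    import torch
    import torch.distributed as dist

    def allreduce_partial_grads(grads: List[torch.Tensor], tp_group: dist.ProcessGroup) -> None:
        # Each rank holds the gradient of its sequence shard; summing over the
        # tensor-parallel group reconstructs the full gradient on every rank.
        for g in grads:
            dist.all_reduce(g, op=dist.ReduceOp.SUM, group=tp_group)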
@@ -489,7 +488,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         Returns:
             None
         """
-
         # Call the superclass backward method to compute gradients.
         super().backward(loss, *args, **kwargs)

@@ -515,7 +513,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         Returns:
             None
         """
-
         # Call the superclass backward method to compute gradients.
         super().backward_by_grad(tensor, grad)

@@ -678,7 +675,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         Returns:
             None
         """
-
         # Call the superclass `_sync_grad` method to synchronize gradients.
         super()._sync_grad()

@@ -923,6 +919,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+        pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+        num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+        enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """

     def __init__(
@@ -958,6 +957,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         custom_policy: Policy = None,
+        pp_style: str = "1f1b",
+        num_model_chunks: int = 1,
+        enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
         assert (
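For illustration, a hypothetical construction of the plugin with the newly added arguments; tp_size, pp_size and num_microbatches follow the existing HybridParallelPlugin API, and the concrete values are placeholders:

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})  # assumes torchrun has set up the distributed env
    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=2,
        num_microbatches=4,
        pp_style="interleaved",      # new: "1f1b" (default) or "interleaved"
        num_model_chunks=2,          # new: must be > 1 for the interleaved schedule
        enable_metadata_cache=True,  # new: cache p2p metadata across microbatches
    )
    booster = Booster(plugin=plugin)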
@@ -984,17 +986,42 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
+            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
-            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-            self.schedule = OneForwardOneBackwardSchedule(
-                self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+            self.stage_manager = PipelineStageManager(
+                self.pg_mesh,
+                pipeline_axis=PP_AXIS,
+                enable_interleave=pp_style == "interleaved",
+                num_model_chunks=num_model_chunks,
             )
+
+            if pp_style == "interleaved":
+                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
+                self.schedule = InterleavedSchedule(
+                    stage_manager=self.stage_manager,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
+                    enable_metadata_cache=enable_metadata_cache,
+                )
+            elif pp_style == "1f1b":
+                self.schedule = OneForwardOneBackwardSchedule(
+                    stage_manager=self.stage_manager,
+                    num_microbatches=num_microbatches,
+                    microbatch_size=microbatch_size,
+                    enable_metadata_cache=enable_metadata_cache,
+                )
+            else:
+                raise NotImplementedError()
+
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             pipeline_stage_manager=self.stage_manager,
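To make the interleaved branch concrete, an illustrative sketch of the chunk-to-rank mapping conventionally used by interleaved 1F1B schedules; this is an assumption about the general technique, not code taken from InterleavedSchedule:

    # With pp_size ranks and num_model_chunks chunks per rank, the model is cut into
    # pp_size * num_model_chunks slices; chunk c on rank r owns slice c * pp_size + r.
    def chunk_slices(rank: int, pp_size: int, num_model_chunks: int) -> list:
        return [c * pp_size + rank for c in range(num_model_chunks)]

    # e.g. 4 pipeline ranks, 2 chunks per rank: rank 1 holds slices 1 and 5.
    assert chunk_slices(rank=1, pp_size=4, num_model_chunks=2) == [1, 5]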
@@ -1035,6 +1062,10 @@ class HybridParallelPlugin(PipelinePluginBase):

         self.max_norm = max_norm

+    def __del__(self):
+        """Destroy the process groups in ProcessGroupMesh"""
+        self.pg_mesh.destroy_mesh_process_groups()
+
     @property
     def enable_pipeline_parallelism(self) -> bool:
         return self.pp_size > 1
@@ -1052,7 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return True

     def support_no_sync(self) -> bool:
-        return False
+        return True

     def control_checkpoint_io(self) -> bool:
         return True
@@ -1146,9 +1177,14 @@ class HybridParallelPlugin(PipelinePluginBase):
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )

+        # run with gradients accumulation
+        if model.require_grad_sync == False or (
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        ):
+            return outputs
+
         # Synchronize the grads of shared parameters of the model.
         model.sync_shared_params()

         # Synchronize sequence parallelism gradients of the model.
         model.sync_sp_grads()

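A hypothetical training-step sketch of how the early return above supports gradient accumulation: while synchronization is deferred, execute_pipeline hands back the outputs right after the forward/backward pass; names such as train_dataloader_iter, criterion and booster are assumed to come from the usual boost() setup:

    # Sketch only: one pipeline step through the booster wrapping this plugin.
    outputs = booster.execute_pipeline(
        train_dataloader_iter, model, criterion, optimizer, return_loss=True
    )
    optimizer.step()
    optimizer.zero_grad()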
@@ -1212,5 +1248,8 @@ class HybridParallelPlugin(PipelinePluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)

-    def no_sync(self, model: Module) -> Iterator[None]:
-        raise NotImplementedError
+    def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+        assert (
+            self.zero_stage != 2
+        ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
+        return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
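Finally, a hypothetical gradient-accumulation loop using the new no_sync signature; the assertion above excludes ZeRO-2, and micro_batches, loss_fn and the boosted model/optimizer are placeholders:

    for step, batch in enumerate(micro_batches):
        if step < len(micro_batches) - 1:
            # defer gradient synchronization for all but the last micro-batch
            with plugin.no_sync(model, optimizer):
                loss = loss_fn(model(**batch))
                booster.backward(loss, optimizer)
        else:
            # last micro-batch: gradients are synchronized as usual
            loss = loss_fn(model(**batch))
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()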