mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-05 11:37:14 +00:00
[feat] zerobubble support moehybridplugin;
This commit is contained in:
parent
1342a983b1
commit
d63479553c
@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC):
|
|||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def pre_backward(self, loss: Tensor) -> Tensor:
|
def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
|
||||||
"""Called before backward.
|
"""Called before backward.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -85,13 +85,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
|
|||||||
master_params.append(master_p)
|
master_params.append(master_p)
|
||||||
group["params"] = master_params
|
group["params"] = master_params
|
||||||
|
|
||||||
def backward(self, loss: Tensor, *args, **kwargs):
|
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||||
loss = self.mixed_precision.pre_backward(loss)
|
loss = self.mixed_precision.pre_backward(loss)
|
||||||
loss.backward(*args, **kwargs)
|
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||||
|
|
||||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||||
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
|
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
|
||||||
tensor.backward(grad)
|
torch.autograd.backward(
|
||||||
|
tensors=tensor,
|
||||||
|
grad_tensors=grad,
|
||||||
|
inputs=inputs,
|
||||||
|
retain_graph=retain_graph,
|
||||||
|
)
|
||||||
|
|
||||||
def zero_grad(self, *args, **kwargs):
|
def zero_grad(self, *args, **kwargs):
|
||||||
for p in self.working_to_master_map.keys():
|
for p in self.working_to_master_map.keys():
|
||||||
|
@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
|
|||||||
growth_interval=growth_interval,
|
growth_interval=growth_interval,
|
||||||
)
|
)
|
||||||
|
|
||||||
def backward(self, loss: Tensor, *args, **kwargs) -> None:
|
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
|
||||||
scaled_loss = self.scale_loss(loss)
|
scaled_loss = self.scale_loss(loss)
|
||||||
scaled_loss.backward(*args, **kwargs)
|
scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||||
|
|
||||||
def step(self, *args, **kwargs) -> Optional[float]:
|
def step(self, *args, **kwargs) -> Optional[float]:
|
||||||
out = self.scaler.step(self.optim, *args, **kwargs)
|
out = self.scaler.step(self.optim, *args, **kwargs)
|
||||||
|
@ -28,7 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
|||||||
from colossalai.interface.optimizer import DistributedOptim
|
from colossalai.interface.optimizer import DistributedOptim
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||||
@ -288,7 +288,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||||||
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||||
super().__init__(optim)
|
super().__init__(optim)
|
||||||
|
|
||||||
def backward(self, loss: Tensor, *args, **kwargs):
|
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
@ -306,7 +306,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Call the superclass backward method to compute gradients.
|
# Call the superclass backward method to compute gradients.
|
||||||
super().backward(loss, *args, **kwargs)
|
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||||
|
|
||||||
if self.model.require_grad_sync:
|
if self.model.require_grad_sync:
|
||||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
@ -315,7 +315,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||||||
# If gradient synchronization is is not required, return.
|
# If gradient synchronization is is not required, return.
|
||||||
return
|
return
|
||||||
|
|
||||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||||
"""
|
"""
|
||||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
@ -332,7 +332,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Call the superclass backward method to compute gradients.
|
# Call the superclass backward method to compute gradients.
|
||||||
super().backward_by_grad(tensor, grad)
|
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
|
||||||
|
|
||||||
if self.model.require_grad_sync:
|
if self.model.require_grad_sync:
|
||||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
@ -512,7 +512,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||||||
max_norm=max_norm,
|
max_norm=max_norm,
|
||||||
)
|
)
|
||||||
|
|
||||||
def backward(self, loss: Tensor, *args, **kwargs):
|
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
@ -529,7 +529,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
# Call the superclass backward method to compute gradients.
|
# Call the superclass backward method to compute gradients.
|
||||||
super().backward(loss, *args, **kwargs)
|
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||||
|
|
||||||
if self.model.require_grad_sync:
|
if self.model.require_grad_sync:
|
||||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
@ -538,7 +538,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||||||
# If gradient synchronization is is not required, return.
|
# If gradient synchronization is is not required, return.
|
||||||
return
|
return
|
||||||
|
|
||||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||||
"""
|
"""
|
||||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
@ -554,7 +554,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
# Call the superclass backward method to compute gradients.
|
# Call the superclass backward method to compute gradients.
|
||||||
super().backward_by_grad(tensor, grad)
|
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
|
||||||
|
|
||||||
if self.model.require_grad_sync:
|
if self.model.require_grad_sync:
|
||||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
@ -768,7 +768,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
def backward(self, loss, retain_graph=False):
|
def backward(self, loss, inputs=None, retain_graph=False):
|
||||||
"""
|
"""
|
||||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
@ -784,7 +784,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
# Call the superclass backward method to compute gradients.
|
# Call the superclass backward method to compute gradients.
|
||||||
super().backward(loss, retain_graph)
|
super().backward(loss, inputs=inputs, retain_graph=retain_graph)
|
||||||
|
|
||||||
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
@ -793,7 +793,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||||||
# If gradient synchronization is is not required, return.
|
# If gradient synchronization is is not required, return.
|
||||||
return
|
return
|
||||||
|
|
||||||
def backward_by_grad(self, tensor, grad):
|
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
|
||||||
"""
|
"""
|
||||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
@ -809,7 +809,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
# Call the superclass backward_by_grad method to compute gradients.
|
# Call the superclass backward_by_grad method to compute gradients.
|
||||||
super().backward_by_grad(tensor, grad)
|
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
|
||||||
|
|
||||||
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
@ -1013,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
custom_policy: Policy = None,
|
custom_policy: Policy = None,
|
||||||
pp_style: str = "1f1b",
|
pp_style: str = "1f1b",
|
||||||
num_model_chunks: int = 1,
|
num_model_chunks: int = 1,
|
||||||
|
scheduler_nodes: List = None,
|
||||||
num_layers_per_stage: Optional[List[int]] = None,
|
num_layers_per_stage: Optional[List[int]] = None,
|
||||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
||||||
enable_metadata_cache: bool = True,
|
enable_metadata_cache: bool = True,
|
||||||
@ -1029,6 +1030,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
dist.get_world_size() % (tp_size * pp_size) == 0
|
dist.get_world_size() % (tp_size * pp_size) == 0
|
||||||
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
not pp_style == "zbv" or scheduler_nodes is not None
|
||||||
|
), f"scheduler_nodes must not be None when using zero bubble pipeline."
|
||||||
if enable_sequence_parallelism:
|
if enable_sequence_parallelism:
|
||||||
self.sequence_parallelism_mode = (
|
self.sequence_parallelism_mode = (
|
||||||
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
|
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
|
||||||
@ -1088,29 +1092,39 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
|
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
|
||||||
|
|
||||||
self.stage_manager = None
|
self.stage_manager = None
|
||||||
self.schedule = None
|
self.scheduler = None
|
||||||
self.custom_policy = custom_policy
|
self.custom_policy = custom_policy
|
||||||
assert zero_stage in (0, 1, 2)
|
assert zero_stage in (0, 1, 2)
|
||||||
if self.pp_size > 1:
|
if self.pp_size > 1:
|
||||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
||||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
assert (
|
||||||
|
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
|
||||||
|
), "num_model_chunks must be 1 when using 1f1b"
|
||||||
|
assert (
|
||||||
|
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
|
||||||
|
), "num_model_chunks must be 2 when using zero bubble pipeline"
|
||||||
assert (
|
assert (
|
||||||
num_microbatches is not None or microbatch_size is not None
|
num_microbatches is not None or microbatch_size is not None
|
||||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||||
assert (
|
assert (
|
||||||
self.zero_stage <= 1
|
self.zero_stage <= 1
|
||||||
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
|
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
|
||||||
|
if pp_style == "zbv":
|
||||||
|
self.logger.warning(
|
||||||
|
"""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
|
||||||
|
)
|
||||||
self.stage_manager = PipelineStageManager(
|
self.stage_manager = PipelineStageManager(
|
||||||
self.pg_mesh,
|
self.pg_mesh,
|
||||||
pipeline_axis=self.pp_axis,
|
pipeline_axis=self.pp_axis,
|
||||||
enable_interleave=(pp_style == "interleaved"),
|
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
|
||||||
|
use_zbv=(pp_style == "zbv"),
|
||||||
num_model_chunks=num_model_chunks,
|
num_model_chunks=num_model_chunks,
|
||||||
num_layers_per_stage=num_layers_per_stage,
|
num_layers_per_stage=num_layers_per_stage,
|
||||||
)
|
)
|
||||||
|
|
||||||
if pp_style == "interleaved":
|
if pp_style == "interleaved":
|
||||||
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
||||||
self.schedule = InterleavedSchedule(
|
self.scheduler = InterleavedSchedule(
|
||||||
stage_manager=self.stage_manager,
|
stage_manager=self.stage_manager,
|
||||||
num_model_chunks=num_model_chunks,
|
num_model_chunks=num_model_chunks,
|
||||||
num_microbatch=num_microbatches,
|
num_microbatch=num_microbatches,
|
||||||
@ -1119,12 +1133,20 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
overlap_p2p=overlap_p2p,
|
overlap_p2p=overlap_p2p,
|
||||||
)
|
)
|
||||||
elif pp_style == "1f1b":
|
elif pp_style == "1f1b":
|
||||||
self.schedule = OneForwardOneBackwardSchedule(
|
self.scheduler = OneForwardOneBackwardSchedule(
|
||||||
stage_manager=self.stage_manager,
|
stage_manager=self.stage_manager,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
microbatch_size=microbatch_size,
|
microbatch_size=microbatch_size,
|
||||||
enable_metadata_cache=enable_metadata_cache,
|
enable_metadata_cache=enable_metadata_cache,
|
||||||
)
|
)
|
||||||
|
elif pp_style == "zbv":
|
||||||
|
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||||
|
stage_manager=self.stage_manager,
|
||||||
|
schedule=scheduler_nodes,
|
||||||
|
num_model_chunks=num_model_chunks,
|
||||||
|
num_microbatch=num_microbatches,
|
||||||
|
microbatch_size=microbatch_size,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if sequence_parallelism_mode == "ring_attn":
|
if sequence_parallelism_mode == "ring_attn":
|
||||||
@ -1236,7 +1258,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
|
|
||||||
# Replace with distributed implementation if exists
|
# Replace with distributed implementation if exists
|
||||||
optimizer = cast_to_distributed(optimizer)
|
optimizer = cast_to_distributed(optimizer)
|
||||||
|
|
||||||
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
|
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
|
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
|
||||||
@ -1352,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||||
|
|
||||||
with ctx, model._wait_all_gather():
|
with ctx, model._wait_all_gather():
|
||||||
outputs = self.schedule.forward_backward_step(
|
outputs = self.scheduler.forward_backward_step(
|
||||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -280,14 +280,17 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
|
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
|
||||||
|
|
||||||
self.stage_manager = None
|
self.stage_manager = None
|
||||||
self.schedule = None
|
self.scheduler = None
|
||||||
self.custom_policy = custom_policy
|
self.custom_policy = custom_policy
|
||||||
assert zero_stage in (0, 1, 2)
|
assert zero_stage in (0, 1, 2)
|
||||||
if self.pp_size > 1:
|
if self.pp_size > 1:
|
||||||
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
||||||
assert (
|
assert (
|
||||||
pp_style == "interleaved" or pp_style == "zbv"
|
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
|
||||||
) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
), "num_model_chunks must be 1 when using 1f1b"
|
||||||
|
assert (
|
||||||
|
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
|
||||||
|
), "num_model_chunks must be 2 when using zero bubble pipeline"
|
||||||
assert (
|
assert (
|
||||||
num_microbatches is not None or microbatch_size is not None
|
num_microbatches is not None or microbatch_size is not None
|
||||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||||
@ -300,11 +303,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||||||
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
|
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
|
||||||
num_model_chunks=num_model_chunks,
|
num_model_chunks=num_model_chunks,
|
||||||
num_layers_per_stage=num_layers_per_stage,
|
num_layers_per_stage=num_layers_per_stage,
|
||||||
|
use_zbv=(pp_style == "zbv"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if pp_style == "interleaved":
|
if pp_style == "interleaved":
|
||||||
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
||||||
self.schedule = InterleavedSchedule(
|
self.scheduler = InterleavedSchedule(
|
||||||
stage_manager=self.stage_manager,
|
stage_manager=self.stage_manager,
|
||||||
num_model_chunks=num_model_chunks,
|
num_model_chunks=num_model_chunks,
|
||||||
num_microbatch=num_microbatches,
|
num_microbatch=num_microbatches,
|
||||||
@ -313,14 +317,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||||||
overlap_p2p=overlap_p2p,
|
overlap_p2p=overlap_p2p,
|
||||||
)
|
)
|
||||||
elif pp_style == "1f1b":
|
elif pp_style == "1f1b":
|
||||||
self.schedule = OneForwardOneBackwardSchedule(
|
self.scheduler = OneForwardOneBackwardSchedule(
|
||||||
stage_manager=self.stage_manager,
|
stage_manager=self.stage_manager,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
microbatch_size=microbatch_size,
|
microbatch_size=microbatch_size,
|
||||||
enable_metadata_cache=enable_metadata_cache,
|
enable_metadata_cache=enable_metadata_cache,
|
||||||
)
|
)
|
||||||
elif pp_style == "zbv":
|
elif pp_style == "zbv":
|
||||||
self.schedule = ZeroBubbleVPipeScheduler(
|
assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV"
|
||||||
|
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||||
schedule=scheduler_nodes,
|
schedule=scheduler_nodes,
|
||||||
stage_manager=self.stage_manager,
|
stage_manager=self.stage_manager,
|
||||||
num_model_chunks=num_model_chunks,
|
num_model_chunks=num_model_chunks,
|
||||||
|
@ -136,7 +136,11 @@ class PipelineStageManager:
|
|||||||
if not self.is_interleave or ignore_chunk:
|
if not self.is_interleave or ignore_chunk:
|
||||||
return self.stage == self.num_stages - 1
|
return self.stage == self.num_stages - 1
|
||||||
else:
|
else:
|
||||||
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
|
# use zero bubble pipeline
|
||||||
|
if self.use_zbv:
|
||||||
|
return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
|
||||||
|
else:
|
||||||
|
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_stages(self) -> int:
|
def num_stages(self) -> int:
|
||||||
|
@ -234,14 +234,28 @@ class MixtralPolicy(Policy):
|
|||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
held_layers = []
|
held_layers = []
|
||||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
|
||||||
if stage_manager.is_first_stage():
|
|
||||||
held_layers.append(module.embed_tokens)
|
|
||||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
|
||||||
held_layers.extend(module.layers[start_idx:end_idx])
|
|
||||||
if stage_manager.is_last_stage():
|
|
||||||
held_layers.append(module.norm)
|
|
||||||
|
|
||||||
|
if stage_manager.is_interleave:
|
||||||
|
assert stage_manager.num_model_chunks is not None
|
||||||
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
|
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
stage_manager.stage_indices = stage_indices
|
||||||
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
for start_idx, end_idx in stage_indices:
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
else:
|
||||||
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
held_layers.append(module.norm)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,17 +7,28 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralModel
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.booster.booster import Booster
|
||||||
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||||
from colossalai.cluster import ProcessGroupMesh
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
||||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from tests.kit.model_zoo import model_zoo
|
from colossalai.testing.random import seed_all
|
||||||
|
from tests.test_moe.moe_utils import assert_loose_close
|
||||||
|
|
||||||
|
NUM_BATCH = 8
|
||||||
|
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
|
||||||
|
NUM_LAYERS = 8
|
||||||
|
HIDDEN_SIZE_PER_HEAD = 4
|
||||||
|
NUM_HEADS = 4
|
||||||
|
TOP_K = 1
|
||||||
|
|
||||||
|
|
||||||
class MlpModel(nn.Module):
|
class MlpModel(nn.Module):
|
||||||
@ -730,127 +741,165 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||||||
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
|
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
|
||||||
|
|
||||||
|
|
||||||
# TODO:4) support Hybrid base 3)
|
# TODO:3) support booster & Hybrid base 2)
|
||||||
def run_with_hybridplugin(test_config):
|
def run_with_hybridplugin(test_config):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# TODO:5) support MoEHybrid base 3)
|
# TODO:4) support booster & MoEHybrid base 2)
|
||||||
@parameterize(
|
@parameterize(
|
||||||
"test_config",
|
"config",
|
||||||
[
|
[
|
||||||
{
|
(0, 1, 4, 1, 1),
|
||||||
"pp_style": "zbv",
|
# (0, 2, 2, 1, 1),
|
||||||
"tp_size": 1,
|
# (0, 2, 1, 2, 1),
|
||||||
"ep_size": 1,
|
# (0, 2, 1, 1, 2),
|
||||||
"pp_size": 4,
|
|
||||||
"num_microbatches": 4,
|
|
||||||
"zero_stage": 1,
|
|
||||||
"precision": "bf16",
|
|
||||||
"num_model_chunks": 2,
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_with_moehybridplugin(test_config):
|
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
|
stage, ep_size, pp_size, tp_size, sp_size = config
|
||||||
# test_config["use_lazy_init"] = False
|
num_microbatches = pp_size
|
||||||
test_config["initial_scale"] = 2**16
|
dist.get_world_size()
|
||||||
model_list = [
|
rank = dist.get_rank()
|
||||||
"transformers_bert",
|
dtype, precision = torch.float16, "fp16"
|
||||||
]
|
torch.cuda.set_device(dist.get_rank())
|
||||||
clear_layout_converter()
|
|
||||||
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
########
|
||||||
if name in model_list:
|
# init base model
|
||||||
# base param
|
########
|
||||||
model = model_fn()
|
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
|
||||||
data = data_gen_fn()
|
config = MixtralConfig(
|
||||||
print(f"data {data}")
|
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
|
||||||
criterion = loss_fn
|
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
|
||||||
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
|
num_hidden_layers=NUM_LAYERS,
|
||||||
|
num_attention_heads=NUM_HEADS,
|
||||||
|
num_key_value_heads=NUM_HEADS,
|
||||||
|
num_local_experts=NUM_EXPERTS,
|
||||||
|
num_experts_per_tok=TOP_K,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
)
|
||||||
|
|
||||||
output = model(**data)
|
# init model with the same seed
|
||||||
loss = criterion(output)
|
seed_all(10086)
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
print(f"output {output}")
|
|
||||||
|
|
||||||
# # pp param
|
torch_model = MixtralModel(config).to(dtype).cuda()
|
||||||
# model_pp = deepcopy(model)
|
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||||
# data_pp = deepcopy(data)
|
# init schedule
|
||||||
# optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
|
h, a, s = config.hidden_size, config.num_attention_heads, 1024
|
||||||
|
mem_f = 34 * h + 5 * a * s
|
||||||
|
mem_w = -32 * h
|
||||||
|
mem_b = -mem_w - mem_f
|
||||||
|
graph = PipelineGraph(
|
||||||
|
n_stage=pp_size,
|
||||||
|
n_micro=num_microbatches,
|
||||||
|
f_cost=1,
|
||||||
|
b_cost=1,
|
||||||
|
w_cost=1,
|
||||||
|
c_cost=1,
|
||||||
|
f_mem=mem_f,
|
||||||
|
b_mem=mem_b,
|
||||||
|
w_mem=mem_w,
|
||||||
|
# max_mem=mem_f * (p * 2 + m_offset),
|
||||||
|
)
|
||||||
|
|
||||||
# # init pipeline graph
|
zbv_schedule = graph.get_v_schedule()
|
||||||
# h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
|
|
||||||
# mem_f = 34 * h + 5 * a * s
|
|
||||||
# mem_w = -32 * h
|
|
||||||
# mem_b = -mem_w - mem_f
|
|
||||||
# graph = PipelineGraph(
|
|
||||||
# n_stage=test_config["pp_size"],
|
|
||||||
# n_micro=test_config["num_microbatches"],
|
|
||||||
# f_cost=1,
|
|
||||||
# b_cost=1,
|
|
||||||
# w_cost=1,
|
|
||||||
# c_cost=1,
|
|
||||||
# f_mem=mem_f,
|
|
||||||
# b_mem=mem_b,
|
|
||||||
# w_mem=mem_w,
|
|
||||||
# # max_mem=mem_f * (p * 2 + m_offset),
|
|
||||||
# )
|
|
||||||
|
|
||||||
# zbv_schedule = graph.get_v_schedule()
|
# init MoeHybridPlugin
|
||||||
|
plugin = MoeHybridParallelPlugin(
|
||||||
|
pp_size=pp_size,
|
||||||
|
num_microbatches=pp_size,
|
||||||
|
tp_size=tp_size,
|
||||||
|
sp_size=sp_size,
|
||||||
|
ep_size=ep_size,
|
||||||
|
zero_stage=stage,
|
||||||
|
enable_sequence_parallelism=sp_size > 1,
|
||||||
|
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
|
||||||
|
overlap_communication=False,
|
||||||
|
initial_scale=1,
|
||||||
|
precision=precision,
|
||||||
|
find_unused_parameters=True,
|
||||||
|
pp_style="zbv",
|
||||||
|
scheduler_nodes=zbv_schedule,
|
||||||
|
num_model_chunks=2,
|
||||||
|
)
|
||||||
|
|
||||||
# test_config["scheduler_nodes"] = zbv_schedule
|
dp_size = plugin.dp_size
|
||||||
# plugin = MoeHybridParallelPlugin(
|
|
||||||
# **test_config
|
|
||||||
# )
|
|
||||||
# model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
|
|
||||||
# model = model_pp,
|
|
||||||
# optimizer = optimizer_pp,
|
|
||||||
# criterion = criterion,
|
|
||||||
# dataloader = data_pp,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# output_pp = plugin.execute_pipeline(
|
booster = Booster(plugin=plugin)
|
||||||
# data_iter=iter(data),
|
|
||||||
# model=model,
|
|
||||||
# criterion=criterion,
|
|
||||||
# optimizer=optimizer,
|
|
||||||
# return_loss = True,
|
|
||||||
# return_outputs = True,
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
########
|
||||||
|
# init pp model
|
||||||
|
########
|
||||||
|
|
||||||
# TODO:6) support booster & Hybrid base 4)
|
parallel_model = deepcopy(torch_model)
|
||||||
|
parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
|
||||||
|
parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
|
||||||
|
# create different input along dp axis
|
||||||
|
seed_all(1453 + rank)
|
||||||
|
|
||||||
|
torch_model.train()
|
||||||
|
parallel_model.train()
|
||||||
|
for _ in range(2):
|
||||||
|
# gen random input
|
||||||
|
input_embeddings = torch.rand(
|
||||||
|
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
|
||||||
|
).cuda()
|
||||||
|
dist.all_reduce(
|
||||||
|
input_embeddings, group=plugin.pp_group
|
||||||
|
) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
|
||||||
|
|
||||||
# TODO:7) support booster & MoEHybrid base 4)
|
dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
|
||||||
@parameterize(
|
dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
|
||||||
"test_config",
|
|
||||||
[
|
# run the model with hybrid parallel
|
||||||
{
|
if booster.plugin.stage_manager is not None:
|
||||||
"pp_style": "zbv",
|
# for test with pp
|
||||||
"tp_size": 1,
|
data_iter = iter([{"inputs_embeds": input_embeddings}])
|
||||||
"ep_size": 1,
|
sharded_output = booster.execute_pipeline(
|
||||||
"pp_size": 4,
|
data_iter,
|
||||||
"num_microbatches": 4,
|
parallel_model,
|
||||||
"zero_stage": 1,
|
lambda x, y: x.last_hidden_state.mean(),
|
||||||
"precision": "bf16",
|
parallel_optimizer,
|
||||||
"num_model_chunks": 2,
|
return_loss=True,
|
||||||
},
|
return_outputs=True,
|
||||||
],
|
)
|
||||||
)
|
# stage 0 chunk 0
|
||||||
def run_with_booster_moehybridplugin(test_config):
|
parallel_output = None
|
||||||
pass
|
if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
|
||||||
|
parallel_output = sharded_output["loss"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# for test without pp
|
||||||
|
parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
|
||||||
|
parallel_optimizer.backward(parallel_output)
|
||||||
|
parallel_optimizer.step()
|
||||||
|
parallel_optimizer.zero_grad()
|
||||||
|
# dist.all_reduce(parallel_output, group=plugin.dp_group)
|
||||||
|
|
||||||
|
# ===================================================================================
|
||||||
|
# run normal model with all dp(different) inputs
|
||||||
|
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
|
||||||
|
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
||||||
|
torch_output_sum = 0
|
||||||
|
for input_data_ in all_inputs:
|
||||||
|
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||||
|
torch_output.backward()
|
||||||
|
torch_output_sum += torch_output.detach()
|
||||||
|
# avg dp grads follows zero optimizer
|
||||||
|
for p in torch_model.parameters():
|
||||||
|
if p.grad is not None:
|
||||||
|
p.grad /= dp_size
|
||||||
|
torch_optimizer.step()
|
||||||
|
torch_optimizer.zero_grad()
|
||||||
|
if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
|
||||||
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
# run_fwd_bwd_iter_input()
|
# run_fwd_bwd_vschedule_with_optim()
|
||||||
run_fwd_bwd_vschedule_with_optim()
|
run_with_booster_moehybridplugin()
|
||||||
# run_with_moehybridplugin()
|
|
||||||
# run_with_booster_moehybridplugin()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
Loading…
Reference in New Issue
Block a user