Merge branch 'main' into ckpt_api

Author: wangbluo
Date:   2024-11-25 10:29:34 +08:00
Commit: fa0318dba5
44 changed files with 3837 additions and 213 deletions


@@ -25,24 +25,20 @@
</div>

-## GPU Cloud HPC-AI.COM Coming
+## Get Started with Colossal-AI Without Setup

-For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
-Plus, when you refer a friend, you'll receive 20% cashback or compute credits equal to 100% of their top-up!
-Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
-Don't miss this incredible opportunity to accelerate your AI projects!
-Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!
+Access high-end, on-demand compute for your research instantly—no setup needed.
+Sign up now and get $10 in credits!

-Special Bonuses:
+Limited Academic Bonuses:

* Top up $1,000 and receive 300 credits
* Top up $500 and receive 100 credits

<div align="center">
-  <a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
-    <img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/HPCAICOM241010.jpg" width="700" />
+  <a href="https://hpc-ai.com/?utm_source=github&utm_medium=social&utm_campaign=promotion-colossalai">
+    <img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2.gif" width="850" />
  </a>
</div>


@@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC):
    dtype: torch.dtype

    @abstractmethod
-    def pre_backward(self, loss: Tensor) -> Tensor:
+    def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
        """Called before backward.

        Args:


@@ -86,13 +86,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
            group["params"] = master_params
        self._current_grad_norm: Optional[float] = None

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        loss = self.mixed_precision.pre_backward(loss)
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
-        tensor.backward(grad)
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )

    def zero_grad(self, *args, **kwargs):
        for p in self.working_to_master_map.keys():


@@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
            growth_interval=growth_interval,
        )

-    def backward(self, loss: Tensor, *args, **kwargs) -> None:
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
        scaled_loss = self.scale_loss(loss)
-        scaled_loss.backward(*args, **kwargs)
+        scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        out = self.scaler.step(self.optim, *args, **kwargs)


@@ -28,7 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
@@ -296,7 +296,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
        self._current_grad_norm: Optional[float] = None
        super().__init__(optim)

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -315,7 +315,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
        # Call the superclass backward method to compute gradients.
        with self.model._hook_context():
-            super().backward(loss, *args, **kwargs)
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -324,7 +324,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
            # If gradient synchronization is not required, return.
            return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -341,7 +341,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
        """
        # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -525,7 +525,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
            max_norm=max_norm,
        )

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -543,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
        """
        # Call the superclass backward method to compute gradients.
        with self.model._hook_context():
-            super().backward(loss, *args, **kwargs)
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -552,7 +552,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
            # If gradient synchronization is not required, return.
            return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -568,7 +568,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
            None
        """
        # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -785,7 +785,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
        else:
            return

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
        """
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -801,7 +801,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
            None
        """
        # Call the superclass backward method to compute gradients.
-        super().backward(loss, retain_graph)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph)

        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -810,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
            # If gradient synchronization is not required, return.
            return

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -826,7 +826,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
            None
        """
        # Call the superclass backward_by_grad method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -1030,6 +1030,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        custom_policy: Policy = None,
        pp_style: str = "1f1b",
        num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
        num_layers_per_stage: Optional[List[int]] = None,
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,
@@ -1048,6 +1049,9 @@ class HybridParallelPlugin(PipelinePluginBase):
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        assert (
+            not pp_style == "zbv" or scheduler_nodes is not None
+        ), f"scheduler_nodes must not be None when using zero bubble pipeline."

        if enable_sequence_parallelism:
            self.sequence_parallelism_mode = (
                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1109,29 +1113,39 @@ class HybridParallelPlugin(PipelinePluginBase):
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

        self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)

        if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using zero bubble pipeline"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert (
                self.zero_stage <= 1
            ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+            if pp_style == "zbv":
+                self.logger.warning(
+                    """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
+                )
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
+                use_zbv=(pp_style == "zbv"),
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
            )
            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,
@@ -1141,13 +1155,21 @@ class HybridParallelPlugin(PipelinePluginBase):
                    fp8_communication=fp8_communication,
                )
            elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                    fp8_communication=fp8_communication,
                )
+            elif pp_style == "zbv":
+                self.scheduler = ZeroBubbleVPipeScheduler(
+                    stage_manager=self.stage_manager,
+                    schedule=scheduler_nodes,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
+                )
            else:
                raise NotImplementedError()
        if sequence_parallelism_mode == "ring_attn":
@@ -1263,7 +1285,6 @@ class HybridParallelPlugin(PipelinePluginBase):
        # Replace with distributed implementation if exists
        optimizer = cast_to_distributed(optimizer)
-
        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
            self.logger.warning(
                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@@ -1278,6 +1299,7 @@ class HybridParallelPlugin(PipelinePluginBase):
            self.dp_size == 1 and self.pp_size == 1
        )
        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
        # Apply Hybrid ZeRO across DP * SP ranks
        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -1380,7 +1402,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
        with ctx, model._hook_context():
-            outputs = self.schedule.forward_backward_step(
+            outputs = self.scheduler.forward_backward_step(
                model, data_iter, criterion, optimizer, return_loss, return_outputs
            )
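
As a reference for the new parameters, here is a minimal sketch of how the zbv path might be wired up. Only the names visible in this hunk (pp_style, num_model_chunks, scheduler_nodes, num_microbatches, zero_stage, tp_size, pp_size) come from the commit; the cost and memory figures passed to PipelineGraph are placeholders, and the overall setup is illustrative rather than prescriptive:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

colossalai.launch_from_torch()

# Per-stage node order for the V-shaped zero-bubble schedule. The cost and
# memory deltas below are made-up unit values; in practice they would be
# profiled per model (see the sketch at the end of v_schedule.py further down).
scheduler_nodes = PipelineGraph(
    n_stage=4, n_micro=8,
    f_cost=6, b_cost=4, w_cost=4, c_cost=1,
    f_mem=3, b_mem=-2, w_mem=-1, max_mem=48,
).get_v_schedule()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="zbv",                  # requires scheduler_nodes, per the new assertion above
    num_model_chunks=2,              # zbv drives two model chunks per stage (the two legs of the V)
    scheduler_nodes=scheduler_nodes,
    num_microbatches=8,
    zero_stage=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(...)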


@@ -141,8 +141,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
            from colossalai.utils.safetensors import save_nested

-            f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-            save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+            f_writer = AsyncFileWriter(
+                fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+            )
+            save_nested(f_writer, state_dict)
            self.async_writers.append(f_writer)
        else:
            save_state_dict(state_dict, checkpoint, use_safetensors=False)
@@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
            from colossalai.utils.safetensors import save_nested

            f_writer = AsyncFileWriter(
-                fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+                fp=open(checkpoint_file_path, "wb", buffering=0),
+                n_entries=self.N_WRITE_ENTRIES,
+                backend="pthread",
            )
            save_nested(f_writer, shard)
            self.async_writers.append(f_writer)
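
The buffering=0 additions in this file (and the similar ones below) make open() return an unbuffered io.FileIO instead of the default io.BufferedWriter, presumably so the pthread-backed AsyncFileWriter operates on the raw file object without an extra Python-level buffer in between. A quick standard-library illustration of the difference (not part of the commit):

import io
import tempfile

path = tempfile.mktemp(suffix=".bin")

buffered = open(path, "wb")            # default: an io.BufferedWriter wrapping the raw file
raw = open(path, "wb", buffering=0)    # unbuffered: an io.FileIO, writes go straight to the OS

assert isinstance(buffered, io.BufferedWriter)
assert isinstance(raw, io.FileIO)

buffered.close()
raw.close()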


@@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -212,6 +213,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        custom_policy: Policy = None,
        pp_style: str = "1f1b",
        num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
        num_layers_per_stage: Optional[List[int]] = None,
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,
@@ -285,12 +287,17 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)

        self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)

        if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using zero bubble pipeline"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -300,14 +307,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
+                use_zbv=(pp_style == "zbv"),
            )
            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,
@@ -316,12 +324,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                    overlap_p2p=overlap_p2p,
                )
            elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                )
+            elif pp_style == "zbv":
+                assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZeroBubbleV"
+                self.scheduler = ZeroBubbleVPipeScheduler(
+                    schedule=scheduler_nodes,
+                    stage_manager=self.stage_manager,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    overlap_p2p=overlap_p2p,
+                )
            else:
                raise NotImplementedError()


@@ -61,7 +61,7 @@ class GeneralCheckpointIO(CheckpointIO):
        if use_async:
            from tensornvme.async_file_io import AsyncFileWriter

-            writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+            writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
            if id(model) not in self.pinned_state_dicts:
                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
            self.async_writers.append(writer)


@@ -702,7 +702,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
            complete_state_dict.update(_state_dict)

            if use_async:
-                writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+                writer = AsyncFileWriter(
+                    open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
+                )
                if id(model) not in self.pinned_state_dicts:
                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
                self.async_writers.append(writer)


@@ -311,7 +311,7 @@ def async_save_state_dict_shards(
        index_file.append_weight_map(key, shard_file)
        checkpoint_file_path = os.path.join(checkpoint, shard_file)

-        writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
+        writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
        writers.append(writer)

        if pinned_state_dict is not None:


@@ -49,14 +49,31 @@ class OptimizerWrapper:
        """
        self.optim.zero_grad(*args, **kwargs)

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        """
        Performs a backward pass on the loss.
        """
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
-        torch.autograd.backward(tensor, grad)
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
+        """
+        Performs a backward pass for dx or dw,
+        for dx, we only calculate dx = w*dy here
+        for dw, we only calculate dw = x*dy here
+
+        Args:
+            tensor (Tensor): y or loss of current chunk;
+            grad_tensors (Tensor): dy of current chunk;
+            input_obj (Tensor): for dx, input_obj is x of current chunk;
+                                for dw, input_obj is w of current chunk;
+            retain_graph (bool): default to be True, we retain graph in backward_b
+        """
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )

    def state_dict(self):
        """

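The new inputs/retain_graph parameters exist so a scheduler can split one backward pass into a dx step and a dw step, as the docstring above describes. A standalone PyTorch illustration of the mechanism (not ColossalAI code; x, w, y and dy stand in for the activation, the weight, the chunk output and the incoming gradient):

import torch

x = torch.randn(4, 8, requires_grad=True)   # activation from the previous stage
w = torch.randn(8, 8, requires_grad=True)   # a weight of the current chunk
y = x @ w                                   # forward of the current chunk
dy = torch.randn_like(y)                    # gradient received from the next stage

# "B" step: only dx is accumulated, and the graph is kept for the later dw step.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# "W" step: only dw is accumulated; the graph can be freed now.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[w], retain_graph=False)
assert w.grad is not None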

@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
        # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
    def torch_adam_update(
        self,
        data,
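
Presumably the new load_state_dict override exists because recent torch.optim.Adam/AdamW checkpoints store each parameter's "step" as a 0-dim tensor, while CPUAdam's update path expects a plain Python int. A small sketch of the mismatch and the same normalization applied by hand (the printed type reflects recent PyTorch releases):

import torch

w = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([w], lr=1e-3)
w.grad = torch.randn(4)
opt.step()

state = opt.state_dict()
step = state["state"][0]["step"]
print(type(step))  # <class 'torch.Tensor'> on recent PyTorch versions

# CPUAdam.load_state_dict performs this normalization after loading:
if isinstance(step, torch.Tensor):
    state["state"][0]["step"] = int(step.item())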


@@ -1,11 +1,12 @@
from .p2p import PipelineP2PCommunication
-from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
+from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
from .stage_manager import PipelineStageManager

__all__ = [
    "PipelineSchedule",
    "OneForwardOneBackwardSchedule",
    "InterleavedSchedule",
+    "ZeroBubbleVPipeScheduler",
    "PipelineP2PCommunication",
    "PipelineStageManager",
]


@@ -432,7 +432,6 @@ def _communicate(
            overlap_p2p=overlap_p2p,
            send_first=send_first if send_first != None else True,
        )
-
    if metadata_recv is not None:
        assert isinstance(metadata_recv, P2PMetadata)
        tree_spec = metadata_recv.tree_spec


@@ -1,9 +1,11 @@
from .base import PipelineSchedule
from .interleaved_pp import InterleavedSchedule
from .one_f_one_b import OneForwardOneBackwardSchedule
+from .zero_bubble_pp import ZeroBubbleVPipeScheduler

__all__ = [
    "PipelineSchedule",
    "OneForwardOneBackwardSchedule",
    "InterleavedSchedule",
+    "ZeroBubbleVPipeScheduler",
]


@@ -137,6 +137,16 @@ def retain_grad(x: Any) -> None:
        x.retain_grad()


+def require_grad(x: Any) -> None:
+    """Call require_grad on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+    """
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
+        x.requires_grad_()
+
+
def detach(x: Any) -> Any:
    """Call detach() on a tensor.

@@ -151,6 +161,34 @@ def detach(x: Any) -> Any:
    return x


+def clone(x: Any) -> Any:
+    """Call clone() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The cloned object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.clone()
+    return x
+
+
+def release_tensor_data(x: Any) -> Any:
+    """Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The deallocate .data object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.data.untyped_storage().resize_(0)
+    return x
+
+
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
    """Merge micro batches into a batch.

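release_tensor_data is what lets the zero-bubble scheduler drop an output's activation storage after it has been sent downstream while keeping grad_fn alive for the later backward. A minimal sketch of why this is safe for an op whose backward does not read its own output (illustration only, not part of the commit):

import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
y = x @ w                      # matmul saves x and w for backward, not y itself

# Drop y's storage; shape and grad_fn metadata survive, the bytes are released.
y.data.untyped_storage().resize_(0)

# Backward still works: the graph only needs the incoming gradient and the
# tensors saved inside grad_fn (x and w), not y's values.
torch.autograd.backward(tensors=y, grad_tensors=torch.randn(4, 8))
assert x.grad is not None and w.grad is not None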

@@ -0,0 +1,449 @@
# Refer from Zero Bubble Pipeline Parallelism.
# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
# Paper: https://arxiv.org/abs/2401.10241
# The following applies to all files unless otherwise noted:
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections import deque
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class ScheduledNode:
type: str
chunk: int
stage: int
minibatch: int
start_time: int = 0
completion_time: int = 0
rollback: bool = False
class PipelineGraph(object):
"""PipelineGraph"""
def __init__(
self,
n_stage,
n_micro,
f_cost,
b_cost,
w_cost,
c_cost,
f_mem,
b_mem,
w_mem,
max_mem=None,
):
self.n_node = 6 * n_stage * n_micro
self.n_stage = n_stage
self.n_micro = n_micro
self.f_cost = f_cost
self.b_cost = b_cost
self.w_cost = w_cost
self.c_cost = c_cost
self.f_mem = f_mem
self.b_mem = b_mem
self.w_mem = w_mem
self.fbw_cost = [f_cost, b_cost, w_cost]
self.fbw_mem = [f_mem, b_mem, w_mem]
self.max_mem = max_mem or f_mem * self.n_stage * 2
def get_id(self, cat, chunk, stage, micro):
return (
cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro
)
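# Editorial note (not in the original commit): get_id flattens the tuple
# (cat, chunk, stage, micro) row-major over dimensions (3, 2, n_stage, n_micro),
# where cat is 0/1/2 for F/B/W and chunk selects the left or right leg of the V.
# That is, get_id(cat, chunk, stage, micro)
#   == ((cat * 2 + chunk) * n_stage + stage) * n_micro + micro,
# which is why n_node above is sized as 6 * n_stage * n_micro.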
def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
count = []
for i in range(self.n_stage):
count.append([0] * 6)
end_time = [-1] * self.n_node
cur_time = [0] * self.n_stage
mem = [0] * self.n_stage
stage_bubble = [0] * self.n_stage
pending_w = [deque() for _ in range(self.n_stage)]
schedule = [[] for _ in range(self.n_stage)]
stage_str = [" " * i for i in range(self.n_stage)]
if approved_bubble is None:
approved_bubble = [-1] * self.n_stage
max_approved_bubble = max(approved_bubble)
def get_max_stage_bubble(stage=-1):
max_stage_bubble = 0
for bb in stage_bubble:
max_stage_bubble = max(max_stage_bubble, bb)
if stage >= 0:
max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
return max_stage_bubble
def put_w(stage):
assert len(pending_w[stage]) > 0
_, chunk_, _ = pending_w[stage].popleft()
put(2, chunk_, stage)
def put(cat, chunk, stage, assert_cnt=True):
_tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
_cnt = count[stage][cat * 2 + chunk]
# assert _cnt < self.n_micro
if _cnt >= self.n_micro:
if not assert_cnt:
stage_str[stage] += " "
cur_time[stage] = _tmp # TODO
return
assert False
assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
if cat > 0 or chunk > 0:
last_id = cat * 2 + chunk - 1
if cat < 2:
assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
else:
assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
if chunk == 1 and cat < 2:
if stage < self.n_stage - 1:
_fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
assert end_time[_fa_id] >= 0
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
if chunk == 0 and cat < 2:
if stage > 0:
_fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
_id = self.get_id(cat, chunk, stage, _cnt)
if count[stage][0] > 0:
stage_bubble[stage] += _tmp - _no_bubble
end_time[_id] = _tmp
cur_time[stage] = _tmp
mem[stage] += self.fbw_mem[cat]
# noinspection PyTypeChecker
schedule[stage].append((cat, chunk, _cnt))
if cat == 1:
pending_w[stage].append((2, chunk, _cnt))
count[stage][cat * 2 + chunk] += 1
for i in range(self.n_stage):
put(0, 0, i)
for i in range(self.n_stage - 1, -1, -1):
if i == self.n_stage - 1:
put(0, 1, i)
continue
tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
while (
mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem
and cur_time[i] + self.fbw_cost[0] <= tmp
and count[i][0] < self.n_micro
):
for j in range(i + 1):
put(0, 0, j)
put(0, 1, i)
iter_chunk_ = 0
end_tmp = 0
for i in range(self.n_stage):
if i == 0:
end_tmp = cur_time[0] + self.fbw_cost[1]
continue
tmp = end_tmp + self.c_cost
while (
count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1]
or count[i][1] <= count[i - 1][1] < self.n_micro
):
for j in range(self.n_stage - 1, i - 1, -1):
if count[j][iter_chunk_] < self.n_micro:
put(0, iter_chunk_, j)
iter_chunk_ = 1 - iter_chunk_
for _ in range(2 * self.n_micro):
# check mem before putting b
for i in range(self.n_stage):
while mem[i] + self.fbw_mem[1] > self.max_mem:
assert len(pending_w[i]) > 0
put_w(i)
b0_ranks, b1_ranks = [], []
for i in range(self.n_stage):
if count[i][3] >= count[i][2]:
b0_ranks.append(i)
elif i == self.n_stage - 1:
b1_ranks.append(i)
else:
fa_id = self.get_id(1, 1, i + 1, count[i][3])
if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
b1_ranks.append(i)
else:
b0_ranks.append(i)
b_ranks = []
# put b1
for i in reversed(b1_ranks):
b_ranks.append((i, 1))
# put b0
for i in b0_ranks:
b_ranks.append((i, 0))
for i, _chunk_ in b_ranks:
fa_id = -1
if _chunk_ == 1 and i < self.n_stage - 1:
fa_id = self.get_id(1, 1, i + 1, count[i][3])
if _chunk_ == 0 and i > 0:
fa_id = self.get_id(1, 0, i - 1, count[i][2])
while (
len(pending_w[i]) > 0
and fa_id >= 0
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
):
# fill the bubble
put_w(i)
if (
len(pending_w[i]) > 0
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
):
if _chunk_ == 1:
put_w(i)
elif fill_b:
put_w(i)
put(1, _chunk_, i)
# put f
for i in range(self.n_stage):
if count[i][1] >= self.n_micro:
continue
put_item = None
if count[i][1] >= count[i][0]:
put_item = 0
elif i == self.n_stage - 1:
put_item = 1
else:
if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
put_item = 1
elif count[i][0] < self.n_micro:
if i == 0:
put_item = 0
elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
put_item = 0
if put_item is None:
continue
# check mem before putting f
while mem[i] + self.fbw_mem[0] > self.max_mem:
assert len(pending_w[i]) > 0
put_w(i)
fa_id = -1
if put_item == 0 and i > 0:
fa_id = self.get_id(0, 0, i - 1, count[i][0])
if put_item == 1 and i < self.n_stage - 1:
fa_id = self.get_id(0, 1, i + 1, count[i][1])
while (
len(pending_w[i]) > 0
and fa_id >= 0
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
):
# fill the bubble
put_w(i)
if (
len(pending_w[i]) > 0
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
):
if fill_f:
put_w(i)
put(0, put_item, i)
for i in range(self.n_stage):
while len(pending_w[i]) > 0:
put_w(i)
max_bubble = get_max_stage_bubble()
expected_time = sum(self.fbw_cost) * self.n_micro * 2
max_bubble / expected_time
if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
_schedule, _end_time, _max_bubble = self.try_v_schedule(
fill_f=fill_f,
fill_b=fill_b,
approved_bubble=stage_bubble,
)
if _max_bubble < max_bubble:
return _schedule, _end_time, _max_bubble
return schedule, end_time, max_bubble
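# Editorial note (not in the original commit): try_v_schedule returns a triple
# (schedule, end_time, max_bubble). schedule[stage] is that stage's execution order
# as (cat, chunk, microbatch) entries with cat 0/1/2 meaning F/B/W; end_time is
# indexed by get_id; max_bubble is the largest per-stage idle time. When the result
# beats the currently approved bubble budget, the method recursively retries with
# the measured per-stage bubbles as the new budget and keeps the tighter schedule.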
def print_details(self, end_time, print_scaling=1):
for stage in range(self.n_stage):
stage_str = ["."] * int(max(end_time) / print_scaling)
for _cat in range(3):
for _chunk in range(2):
for _micro in range(self.n_micro):
_id = self.get_id(_cat, _chunk, stage, _micro)
if end_time[_id] < 0:
continue
end = int(end_time[_id] / print_scaling)
start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
for j in range(start, end):
if j == start or j == end - 1:
stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
elif j == start + 1:
if _micro >= 10:
stage_str[j] = str(_micro // 10)
else:
stage_str[j] = str(_micro)
elif j == start + 2 and _micro >= 10:
stage_str[j] = str(_micro % 10)
else:
stage_str[j] = "-"
_str = ""
for _c in stage_str:
_str += _c
print(_str)
def get_v_schedule(self, only_run_time=False):
schedule, end_time, max_bubble = None, None, None
expected_time = sum(self.fbw_cost) * self.n_micro * 2
for fill_b in [True, False]:
for fill_f in [True, False]:
_schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f)
if max_bubble is None or _max_bubble < max_bubble:
max_bubble = _max_bubble
schedule = _schedule
end_time = _end_time
if only_run_time:
return max_bubble + expected_time
max_bubble / (expected_time + max_bubble)
local_order = [[] for _ in range(self.n_stage)]
comm_id = {}
comm_id_counter = 0
post_validation_time = 0
for i in range(self.n_stage - 1, -1, -1):
pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
post_validation_time = max(
post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost
)
# post_validation_time = 0
for it in ["RECV_", "SEND_", ""]:
if i == 0 and it == "SEND_":
continue
if i == self.n_stage - 1 and it == "RECV_":
continue
# stage_ = i - 1 if it == "RECV_" else i
stage_ = i
local_order[stage_].append(
ScheduledNode(
type=it + "POST_VALIDATION",
chunk=0,
stage=stage_,
minibatch=0,
start_time=post_validation_time,
completion_time=post_validation_time,
)
)
comm_id[local_order[stage_][-1]] = comm_id_counter
comm_id_counter += 1
for i in range(self.n_stage):
for _cat_, _chunk_, _micro_ in schedule[i]:
complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
local_order[i].append(
ScheduledNode(
type="FBW"[_cat_],
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
stage=i,
minibatch=_micro_,
start_time=complete_time - self.fbw_cost[_cat_],
completion_time=complete_time,
)
)
if _cat_ == 2: # no communication for W
continue
cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
def communicate(send_recv, stage_):
# noinspection PyTypeChecker
local_order[stage_].append(
ScheduledNode(
type=send_recv + cat_str,
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
stage=stage_,
minibatch=_micro_,
start_time=complete_time,
completion_time=complete_time,
)
)
comm_id[local_order[stage_][-1]] = comm_id_counter
if _chunk_ == 1 and i > 0:
communicate("SEND_", i)
communicate("RECV_", i - 1)
if _chunk_ == 0 and i < self.n_stage - 1:
communicate("SEND_", i)
communicate("RECV_", i + 1)
comm_id_counter += 1
for rank in range(self.n_stage):
# For nodes with the same timestamp on the same stage, communication will be prioritized.
def even_breaker(x: ScheduledNode):
# Compute nodes are always delayed.
if x.type in ["F", "B", "W"]:
return comm_id_counter
# For comm nodes, order by their unique comm id
return comm_id[x]
local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x))))
# If a recv overlaps with a previous computation, reorder them so that the recv
# is executed before the computation and hence can be overlapped.
for i in range(len(local_order[rank])):
if (
i > 0
and local_order[rank][i - 1].type in {"F", "B", "W"}
and local_order[rank][i].type.startswith("RECV")
and "POST_VALIDATION" not in local_order[rank][i].type
and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time
):
local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
local_order_with_rollback = [[] for _ in range(self.n_stage)]
for rank in range(self.n_stage):
rollback_comm = set()
if rank > 0:
for node in local_order[rank - 1]:
if node.type == "POST_VALIDATION":
break
if node.type == "SEND_FORWARD":
assert node.chunk == 0
rollback_comm.add(node.minibatch)
for node in local_order[rank]:
if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
rollback = True
rollback_comm.remove(node.minibatch)
else:
rollback = False
local_order_with_rollback[rank].append(
ScheduledNode(
type=node.type,
chunk=node.chunk,
stage=node.stage,
minibatch=node.minibatch,
start_time=node.start_time,
completion_time=node.completion_time,
rollback=rollback,
)
)
assert len(rollback_comm) == 0
return local_order_with_rollback
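
For orientation, here is a small self-contained sketch (not part of the commit) that builds a 4-stage, 8-microbatch graph and inspects what get_v_schedule() hands back: one time-ordered list of ScheduledNode per stage, mixing compute (F/B/W) entries with SEND_/RECV_ and POST_VALIDATION bookkeeping nodes. The cost and memory figures are placeholders; b_mem and w_mem are negative so a full F+B+W round returns memory to its starting level, and max_mem is chosen generously for this toy setting.

from colossalai.pipeline.schedule.v_schedule import PipelineGraph

n_stage, n_micro = 4, 8
graph = PipelineGraph(
    n_stage=n_stage,
    n_micro=n_micro,
    f_cost=6, b_cost=4, w_cost=4, c_cost=1,   # placeholder unit costs
    f_mem=3, b_mem=-2, w_mem=-1,              # placeholder activation deltas
    max_mem=3 * n_stage * 4,                  # generous bound for the toy run
)
local_order = graph.get_v_schedule()
assert len(local_order) == n_stage
for rank, nodes in enumerate(local_order):
    kinds = [node.type for node in nodes]
    print(f"stage {rank}: {len(nodes)} nodes, "
          f"{kinds.count('F')} F / {kinds.count('B')} B / {kinds.count('W')} W")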


@@ -0,0 +1,958 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_flatten, tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.pipeline.weight_grad_store import WeightGradStore
from ._utils import (
clone,
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
release_tensor_data,
require_grad,
retain_grad,
to_device,
)
from .base import PipelineSchedule
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
class ZeroBubbleVPipeScheduler(PipelineSchedule):
def __init__(
self,
stage_manager: PipelineStageManager,
schedule: List[ScheduledNode],
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
):
super().__init__(stage_manager)
# batch info
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
self.batch: Any
self.batch_size: int
self.last_batch_size: Optional[int] = None
self.microbatch_offset: List[int]
self.schedules = schedule
# TODO: optim post valid
self.do_post_validation = False
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
# check send_tensor_metadata, send_grad_metadata
# taking pp=4 as an example, we follow this metadata send/recv strategy
# send_tensor_meta(fwd) send_grad_meta(bwd)
# chunk0 | chunk1 chunk0 | chunk 1
# stage 0 T | F F | T
# stage 1 T | T T | T
# stage 2 T | T T | T
# stage 3 F | T F | T
if stage_manager.is_first_stage(ignore_chunk=True):
self.send_tensor_metadata = [True, False]
self.send_grad_metadata = [False, True]
elif stage_manager.is_last_stage(ignore_chunk=True):
self.send_tensor_metadata = [False, True]
self.send_grad_metadata = [True, False]
else:
self.send_tensor_metadata = [True, True]
self.send_grad_metadata = [True, True]
# meta cache buffer
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
self.grad_metadata_recv = [None, None]
# P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
# init communication map
self.communication_map = {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}
# init buffer
self._free_buffers()
def _free_buffers(self):
# free local buffer
# two dim array, first dim is the model chunk, second dim is the microbatch queue
# x & y buffer for schedule b
self.input_tensors = [[], []]
self.output_tensors = [[], []]
# y & dy buffer for schedule w
self.output_tensors_dw = [[], []]
self.output_tensors_grad_dw = [[], []]
# buffer for communication
self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_forward_buffer = [
[],
[],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_backward_buffer = [
[],
[],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
# y buffer for local send fwd
self.local_send_forward_buffer = []
# dy buffer for local send bwd
self.local_send_backward_buffer = []
# wait pp buffer
self.wait_handles = []
def assert_buffer_empty(self):
# assert buffer is empty at end
assert len(self.input_tensors[0]) == 0
assert len(self.input_tensors[1]) == 0
assert len(self.output_tensors[0]) == 0
assert len(self.output_tensors[1]) == 0
assert len(self.output_tensors_dw[0]) == 0
assert len(self.output_tensors_dw[1]) == 0
assert len(self.output_tensors_grad_dw[0]) == 0
assert len(self.output_tensors_grad_dw[1]) == 0
assert len(self.send_forward_buffer[0]) == 0
assert len(self.send_forward_buffer[1]) == 0
assert len(self.recv_forward_buffer[0]) == 0
assert len(self.recv_forward_buffer[1]) == 0
assert len(self.send_backward_buffer[0]) == 0
assert len(self.send_backward_buffer[1]) == 0
assert len(self.recv_backward_buffer[0]) == 0
assert len(self.recv_backward_buffer[1]) == 0
assert len(self.local_send_forward_buffer) == 0
assert len(self.local_send_backward_buffer) == 0
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should be divisible by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatch
if self.num_microbatch is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
self.num_microbatch = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatch
assert (
self.num_microbatch % self.stage_manager.num_stages == 0
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
if self.forward_only:
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
self.last_batch_size = self.batch_size
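# Editorial note (not in the original commit): for example, a global batch of 32
# samples with microbatch_size=4 gives num_microbatch=8, which in training mode must
# also be a multiple of the number of pipeline stages (e.g. 4) per the assertion above.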
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
Args:
microbatch_id (int): the current model chunk idx.
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
microbatch_id (int): the current microbatch idx
forward (bool): if is the forward process
Returns:
int: The model chunk idx of the input microbatch_id
"""
assert (
microbatch_id < self.num_microbatch * self.num_model_chunks
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not is_forward:
# Reverse order
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
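# Editorial note (not in the original commit): with num_model_chunks=2 and, say,
# 4 pipeline stages, microbatch ids cycle through groups of
# num_stages * num_model_chunks = 8. Within a group, forward ids 0-3 map to chunk 0
# (one leg of the V) and ids 4-7 to chunk 1 (the other leg); for the backward
# direction the mapping is mirrored, so ids 0-3 map to chunk 1 and 4-7 to chunk 0.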
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
################
# chunk = 0 & is_first_stage
# do nothing; chunk 0 on the first rank has no previous rank to receive from;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
################
# chunk = 0 & not is_first_stage
# Recv y from PREV_rank as input
#################
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor, wait_handles = self.comm.recv_forward(
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles
else:
################
# chunk = 1 & is_last_stage
# do nothing; y is taken from local_send_forward_buffer in schedule f
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
# return None, []
return []
################
# chunk = 1 & not is_last_stage
# recv y from NEXT_rank as input
################
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor, wait_handles = self.comm.recv_forward(
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
# bwd chunk0 is right V;
################
# chunk = 0 & is_last_stage
# do nothing; Already get dy from local_send_backward_buffer in schedule b
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return []
################
# chunk = 0 & not is_last_stage
# Recv bwd from next stage;
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles
else:
# bwd chunk1 is left V;
################
# chunk = 1 & is_first_stage
# do nothing; get loss from local
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
################
# chunk = 1 & not first stage
# recv_backward recv bwd from prev stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
################
# chunk = 0 && is_last_stage
# do nothing; hold y on local_send_forward_buffer
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 0 && not is_last_stage
# self.comm.send_forward send y to NEXT stage
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(
output_object=output_tensor,
next_rank=next_rank,
send_metadata=self.send_tensor_metadata[model_chunk_id],
)
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
else:
################
# chunk = 1 && is_first_stage
# do nothing; the LOSS stays local (end of forward), so there is nothing to send
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 1 && not is_first_stage
# self.comm.send_forward send y to PREV stage
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
)
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
# bwd chunk0 is right V;
################
# chunk = 0 && is_first_stage
# do nothing; this is the first chunk on the first stage, so backward ends here
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 0 && not is_first_stage
# Send dx to PREV stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
)
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
# bwd chunk1 is left V;
else:
################
# chunk = 1 && is_last_stage
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 1 && not is_last_stage
# Send dx to NEXT stage;
################
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
)
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
def forward_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
micro_batch: Optional[dict],
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
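micro_batch (Optional[dict]): The current micro batch of data;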
input_obj (Optional[dict]): x;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
# for the first stage, input_obj is None, so we use micro_batch as input_obj
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
# last stage of the model: in ZBV, chunk 1 on the first physical stage computes the loss
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_b_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
# micro_batch: Optional[dict],
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
) -> Optional[dict]:
"""Backward dx step of the pipeline; we calculate "dx = w*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): x, the input of this model chunk.
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Optional[dict]: dx.
"""
# calculate the bwd B step; only dx = w*dy
# Retain grad on input_obj; no need to retain grad on the micro_batch
if input_obj is not None:
tree_map(retain_grad, input_obj)
# x, y, dy list for backward_by_grad; Type: list[tensor];
input_obj_ = []
output_obj_ = []
output_obj_grad_ = []
# For chunk 0 on stage 0, micro_batch is used as input_obj_, and we do not need to calculate dx for the micro_batch.
# For loss backward; output_obj is loss; output_obj_grad should be None
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None
input_obj_, _ = tree_flatten(input_obj)
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(output_obj_grad) # None
# For other chunk stage, use input_obj as input_obj_;
else:
input_obj_, _ = tree_flatten(input_obj)
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
# keep only torch.Tensor entries (and None placeholders)
input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
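# Run this partial backward under no_sync so gradient synchronization (e.g. DDP/ZeRO reduction) is not triggered here; the B step only produces dx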
try:
ctx = optimizer.no_sync()
except AttributeError:
ctx = model_chunk.no_sync()
with ctx:
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
# inputs=input_obj_,
retain_graph=False,
)
# Collect the input gradients (dx) into a dict
input_obj_grad = dict()
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
return input_obj_grad
def backward_w_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
):
"""Backward dw step of the pipeline; we calculate "dw = x*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Nothing to return; we only calculate dw here (the weights are updated later by the optimizer).
"""
# calculate bwd w step ; only dw = x*dy;
# y, dy list for w backward_by_grad; Type: list[tensor];
output_obj_ = []
output_obj_grad_ = []
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss;
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(None) # None
else:
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
# keep only torch.Tensor entries (and None placeholders)
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
inputs=list(model_chunk.parameters()),
retain_graph=False,
)
def schedule_f(
self,
scheduled_node,
model_chunk: torch.nn.ModuleList,
model_chunk_id: int,
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
):
"""A complete forward schedule; Include recv fwd --> cal fwd --> send fwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Nothing.
"""
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from microbatch
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = None # (tensor, wait_handle)
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
else:
# is last stage; recv from local
if self.stage_manager.is_last_stage(ignore_chunk=True):
input_obj = self.local_send_forward_buffer.pop(0)
# not last stage; recv from next
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
# Mark the received inputs as requiring grad so that dx can be computed in the backward B step
if not isinstance(input_obj, torch.Tensor):
tree_map(require_grad, input_obj)
# Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
# tree_map(torch.Tensor.requires_grad_, micro_batch)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
micro_batch=micro_batch,
input_obj=input_obj,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# Step3:
# 3-1: detach output for send fwd;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not detach bwd LOSS
pass
else:
# detach output
detached_output_obj = tree_map(detach, output_obj)
# 3-2 clone detached_output_obj
detached_output_obj = tree_map(clone, detached_output_obj)
# 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not release_tensor_data bwd LOSS
pass
else:
# release_tensor_data output
tree_map(release_tensor_data, output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not release_tensor_data loss, release_tensor_data other output_obj;
self.output_tensors[model_chunk_id].append(output_obj)
# add output to send_fwd_buffer
if model_chunk_id == 0: # chunk 0
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(detached_output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else: # chunk 1
# is first stage; end of fwd; do nothing
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
def schedule_b(
self,
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
):
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
# Step1: recv bwd
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# chunk0 not last stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
else:
# chunk1, is first stage; the LOSS stays local, so there is no dy to receive
if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = None
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
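# Step2: backward B step; compute dx (the dw work for this micro batch is cached in WeightGradStore by the custom autograd functions)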
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
)
# Step3: send bwd
if model_chunk_id == 0:
# do nothing; end of bwd;
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
# save input_object_grad to send_backward_buffer
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
else:
# send to local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_backward_buffer.append(input_object_grad)
# send to next
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
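# Seal this micro batch's deferred dw work: move the jobs cached during the B step into the per-chunk queue consumed by schedule_w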
WeightGradStore.flush(chunk=model_chunk_id)
def schedule_w(
self,
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
):
"""A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
WeightGradStore.pop(chunk=model_chunk_id)
def run_forward_only(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
assert self.forward_only
# prepare batch
self.load_batch(data_iter)
# prepare accum loss & output
accum_loss = None
# create the loss accumulator on the stage that computes the loss
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
# while we still have schedules_node in self.schedules
for it in range(len(self.schedules)):
scheduled_node = self.schedules[it]
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs Zerobubble schedule, with communication between pipeline stages.
"""
# prepare batch
self.load_batch(data_iter)
# prepare accum loss & output
accum_loss = None
# create the loss accumulator on the stage that computes the loss
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
# while we still have schedules_node in self.schedules
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)):
scheduled_node = schedule[it]
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]
wait_handle = communication_func(scheduled_node.chunk)
# Recv handles are waited on inside the fwd/bwd steps; only send handles need to be tracked and waited on here
if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
elif scheduled_node.type == "B":
self.schedule_b(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
elif scheduled_node.type == "W":
self.schedule_w(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
# wait here to ensure all communication is done
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)
self.assert_buffer_empty()
return result
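# Illustration (not part of the schedule implementation): a minimal sketch of the B/W split
# that backward_b_step / backward_w_step rely on, written with plain torch.autograd on a toy
# linear layer. The real schedule avoids retain_graph by caching the weight-grad GEMMs in
# WeightGradStore during the B step; all names in this helper are hypothetical and exist only
# for this example.
def _zbv_bw_split_example():
    toy_linear = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    y = toy_linear(x)
    dy = torch.ones_like(y)
    # B step: only dx = dL/dx; keep the graph so dw can still be computed afterwards.
    torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
    assert x.grad is not None and toy_linear.weight.grad is None
    # W step (deferred): dw = dL/dw, accumulated only into the parameters.
    torch.autograd.backward(y, grad_tensors=dy, inputs=list(toy_linear.parameters()))
    assert toy_linear.weight.grad is not None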

View File

@ -26,6 +26,7 @@ class PipelineStageManager:
pg_mesh: ProcessGroupMesh, pg_mesh: ProcessGroupMesh,
pipeline_axis: int, pipeline_axis: int,
enable_interleave: bool = False, enable_interleave: bool = False,
use_zbv: bool = False,
num_model_chunks: int = 1, num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None, num_layers_per_stage: Optional[List[int]] = None,
) -> None: ) -> None:
@ -49,6 +50,7 @@ class PipelineStageManager:
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
self.is_interleave = enable_interleave self.is_interleave = enable_interleave
self.use_zbv = use_zbv
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model # for shardformer, hold stage indices of model
@ -85,6 +87,16 @@ class PipelineStageManager:
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
stage_indices = [] stage_indices = []
if self.use_zbv:
stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
stage_indices.append(
[
num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
num_layers_per_stage_accumulated[2 * num_stages - stage],
]
)
return stage_indices
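# Example (illustrative): with num_stages = 4 and num_model_chunks = 2 there are 8 layer groups;
# stage 0 holds [acc[0], acc[1]) and [acc[7], acc[8]), stage 3 holds [acc[3], acc[4]) and [acc[4], acc[5]),
# where acc = num_layers_per_stage_accumulated. This is the V-shaped placement used by the zero-bubble schedule.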
for model_chunk in range(num_model_chunks): for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
@ -123,6 +135,10 @@ class PipelineStageManager:
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None) assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
if not self.is_interleave or ignore_chunk: if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1 return self.stage == self.num_stages - 1
else:
# zero bubble (ZBV) pipeline: the last model chunk folds back onto physical stage 0
if self.use_zbv:
return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
else: else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
@ -207,7 +223,6 @@ class PipelineStageManager:
# calculate the num_layers per stage # calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers # deal with the rest layers
if remainder > 0: if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 start_position = (num_stages * num_model_chunks) // 2 - remainder // 2

View File

@ -0,0 +1,32 @@
import queue
class WeightGradStore:
cache = []
weight_grad_queue = [queue.Queue(), queue.Queue()]
@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))
@classmethod
def flush(cls, chunk=0):
cls.weight_grad_queue[chunk].put(cls.cache)
cls.cache = []
@classmethod
def pop(cls, chunk=0):
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")

View File

@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule from .parallel_module import ParallelModule
@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLin
__all__ = [ __all__ = [
"Embedding1D", "Embedding1D",
"VocabParallelEmbedding1D", "VocabParallelEmbedding1D",
"LinearWithGradAccum",
"Linear1D_Col", "Linear1D_Col",
"Linear1D_Row", "Linear1D_Row",
"GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Col",

View File

@ -1,7 +1,11 @@
import functools
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.pipeline.weight_grad_store import WeightGradStore
from .utils import is_share_sp_tp from .utils import is_share_sp_tp
try: try:
@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
""" """
@staticmethod @staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias) ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
ctx.process_group = process_group ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None: if bias is not None:
output = F.linear(input_, weight, bias) output = F.linear(input_, weight, bias)
else: else:
@ -143,6 +148,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
input, weight, bias = ctx.saved_tensors input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias: if use_bias:
@ -164,9 +176,35 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None: if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad grad = weight.grad
if use_zbv:
# Defer the weight-gradient GEMM: cache (input, grad_output, weight, func) in WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32: if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None grad_weight = None
@ -175,6 +213,18 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
grad_weight = None grad_weight = None
else: else:
grad_weight = grad_output.t().matmul(total_input) grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else: else:
grad_weight = grad_output.t().matmul(total_input) grad_weight = grad_output.t().matmul(total_input)
@ -182,6 +232,104 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce and not fp8_communication: if ctx.async_grad_allreduce and not fp8_communication:
handle.wait() handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithGradAccum(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""
@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# Defer the weight-gradient GEMM: cache (input, grad_output, weight, func) in WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None, None return grad_input, grad_weight, grad_bias, None, None, None, None
@ -966,12 +1114,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
) )
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return LinearWithAsyncCommunication.apply( return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
) )
def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
def linear_gather_forward_reducescatter_backward( def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
): ):

View File

@ -27,13 +27,155 @@ from ._operation import (
linear_gather_forward_reducescatter_backward, linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward, linear_reducescatter_forward_gather_backward,
linear_with_async_comm, linear_with_async_comm,
linear_with_grad_accum,
reduce_forward, reduce_forward,
split_forward_gather_backward, split_forward_gather_backward,
) )
from .parallel_module import PaddingParallelModule, ParallelModule from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset, is_share_sp_tp from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["Linear1D_Col", "Linear1D_Row"] __all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
class LinearWithGradAccum(ParallelModule):
r"""Linear layer with no parallelism.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
use_zbv (`bool`, optional): If set to ``True``, the weight-gradient computation is deferred to the zero-bubble W step via ``WeightGradStore``, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
use_zbv: bool = False,
**kwargs,
):
super().__init__(weight=weight, bias_=bias_, **kwargs)
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.skip_bias_add = skip_bias_add
self.device = device
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
linear_1d = LinearWithGradAccum(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# No tensor parallelism here; the input is used as-is.
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_grad_accum(
input_parallel,
self.weight,
bias,
False,
use_zbv=self.use_zbv,
)
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
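# Illustration (not part of this module): a minimal sketch of the functional path this layer uses.
# With use_zbv=True the weight-grad GEMM is cached in WeightGradStore during backward (the B step)
# and weight.grad stays None until the W step replays it; with use_zbv=False it is computed
# immediately. Names and shapes are for this example only.
def _linear_with_grad_accum_example():
    x = torch.randn(4, 16, requires_grad=True)
    weight = torch.nn.Parameter(torch.randn(32, 16))
    bias = torch.nn.Parameter(torch.zeros(32))
    out = linear_with_grad_accum(x, weight, bias, False, use_zbv=True)
    out.sum().backward()  # B step: x.grad is produced, weight.grad is still None
    assert x.grad is not None and weight.grad is None
    # The deferred dw is executed later via WeightGradStore.flush(...) / WeightGradStore.pop(...).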
class Linear1D_Col(ParallelModule): class Linear1D_Col(ParallelModule):
@ -81,6 +223,7 @@ class Linear1D_Col(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False, fp8_communication: bool = False,
use_zbv: bool = False,
**kwargs, **kwargs,
): ):
super().__init__(weight=weight, bias_=bias_, **kwargs) super().__init__(weight=weight, bias_=bias_, **kwargs)
@ -95,6 +238,7 @@ class Linear1D_Col(ParallelModule):
self.device = device self.device = device
self.process_group = process_group self.process_group = process_group
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias: if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None") raise ValueError("cannot skip bias addition if bias is None")
@ -209,9 +353,14 @@ class Linear1D_Col(ParallelModule):
) )
else: else:
output_parallel = linear_with_async_comm( output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication input_parallel,
self.weight,
bias,
self.process_group,
True,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
) )
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_forward_split_backward( output = gather_forward_split_backward(
@ -267,6 +416,7 @@ class Linear1D_Row(ParallelModule):
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1, stream_chunk_num: int = 1,
fp8_communication: bool = False, fp8_communication: bool = False,
use_zbv: bool = False,
): ):
super().__init__() super().__init__()
@ -282,6 +432,7 @@ class Linear1D_Row(ParallelModule):
self.seq_parallel_dim = seq_parallel_dim self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group) self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias: if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None") raise ValueError("cannot skip bias addition if bias is None")

View File

@ -82,7 +82,7 @@ class LlamaPipelineForwards:
elif input_ids is not None: elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2] batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None: elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape[:2] batch_size, seq_length = inputs_embeds.shape[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None: if inputs_embeds is None:
@ -191,7 +191,6 @@ class LlamaPipelineForwards:
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,
) )
assert num_ckpt_layers <= end_idx - start_idx assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)

View File

@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
moe_dp_group: ProcessGroup, moe_dp_group: ProcessGroup,
ep_group: ProcessGroup, ep_group: ProcessGroup,
fp8_communication: bool = False, fp8_communication: bool = False,
use_zbv: bool = False,
): ):
assert tp_group is not None assert tp_group is not None
assert moe_dp_group is not None assert moe_dp_group is not None
@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
self.ep_rank = dist.get_rank(ep_group) self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group self.ep_group = ep_group
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if self.num_experts % self.ep_size != 0: if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule):
if self.tp_group.size() > 1: if self.tp_group.size() > 1:
for expert in held_experts: for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module( expert.w1 = Linear1D_Col.from_native_module(
expert.w1, self.tp_group, fp8_communication=self.fp8_communication expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
) )
expert.w3 = Linear1D_Col.from_native_module( expert.w3 = Linear1D_Col.from_native_module(
expert.w3, self.tp_group, fp8_communication=self.fp8_communication expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
) )
expert.w2 = Linear1D_Row.from_native_module( expert.w2 = Linear1D_Row.from_native_module(
expert.w2, self.tp_group, fp8_communication=self.fp8_communication expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
) )
for p in self.experts.parameters(): for p in self.experts.parameters():
@ -379,7 +381,6 @@ class MixtralPipelineForwards:
output_router_logits, output_router_logits,
use_cache, use_cache,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache: if use_cache:
@ -399,6 +400,7 @@ class MixtralPipelineForwards:
if output_router_logits and past_router_logits is not None: if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
if not return_dict: if not return_dict:
return tuple( return tuple(
@ -512,7 +514,6 @@ class MixtralPipelineForwards:
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits.float() logits = logits.float()
loss = None loss = None
if labels is not None: if labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n

View File

@ -75,6 +75,8 @@ class BertPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather" sp_partial_derived = sp_mode == "split_gather"
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
assert ( assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -97,6 +99,7 @@ class BertPolicy(Policy):
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -105,6 +108,7 @@ class BertPolicy(Policy):
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -113,6 +117,7 @@ class BertPolicy(Policy):
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -125,6 +130,7 @@ class BertPolicy(Policy):
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -138,6 +144,7 @@ class BertPolicy(Policy):
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused, "skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -146,6 +153,97 @@ class BertPolicy(Policy):
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
policy[BertEmbeddings] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
]
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bert_intermediate_forward(),
},
policy=policy,
target_key=BertIntermediate,
)
elif use_zbv:
policy[BertLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(

View File

@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm, FusedRMSNorm,
Linear1D_Col, Linear1D_Col,
Linear1D_Row, Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding, PaddingEmbedding,
PaddingLMHead, PaddingLMHead,
RMSNorm, RMSNorm,
@ -60,6 +61,8 @@ class LlamaPolicy(Policy):
else: else:
norm_cls = RMSNorm norm_cls = RMSNorm
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
sp_mode = self.shard_config.sequence_parallelism_mode or None sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
@ -102,7 +105,7 @@ class LlamaPolicy(Policy):
policy=policy, policy=policy,
target_key=LlamaModel, target_key=LlamaModel,
) )
# tensor parallelism enabled: replace linear layers with Linear1D_Col / Linear1D_Row
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
assert ( assert (
num_q_heads % tp_size == 0 num_q_heads % tp_size == 0
@ -126,37 +129,135 @@ class LlamaPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.q_proj", suffix="self_attn.q_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.k_proj", suffix="self_attn.k_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.v_proj", suffix="self_attn.v_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.o_proj", suffix="self_attn.o_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.gate_proj", suffix="mlp.gate_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.up_proj", suffix="mlp.up_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.down_proj", suffix="mlp.down_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
],
)
# tensor parallelism not enabled: replace linear layers with LinearWithGradAccum
elif use_zbv:
policy[LlamaDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
], ],
) )
@ -261,9 +362,10 @@ class LlamaPolicy(Policy):
held_layers.append(module.embed_tokens) held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices: for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx]) held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True): if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm) held_layers.append(module.norm)
else: else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
@ -353,11 +455,15 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True): if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head) held_layers.append(self.model.lm_head)
return held_layers return held_layers
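The branch above encodes the zero-bubble V-schedule (zbv) convention: the first pipeline rank runs both the first and the last model chunk, so the final norm and the lm_head are held by the first stage rather than the last. A minimal sketch of that selection rule, assuming only a stage manager exposing use_zbv / is_first_stage / is_last_stage as in the diff (the helper name is illustrative, not a library API):

def holds_final_layers(stage_manager) -> bool:
    # zbv: the first stage runs the last forward chunk
    if stage_manager.use_zbv:
        return stage_manager.is_first_stage(ignore_chunk=True)
    # 1F1B / interleaved: the last stage runs the last forward chunk
    return stage_manager.is_last_stage(ignore_chunk=True)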
def get_shared_params(self) -> List[Dict[int, Tensor]]: def get_shared_params(self) -> List[Dict[int, Tensor]]:
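# under zbv the embedding and lm_head end up on the same (first) stage, so there are no cross-stage tied weights to synchronize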
if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
return []
llama_model = self.model.model llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if ( if (
@ -379,7 +485,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
from transformers import LlamaForSequenceClassification from transformers import LlamaForSequenceClassification
policy = super().module_policy() policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TP enabled: replace the linear layers with Linear1D_Col / Linear1D_Row
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification # add a new item for sequence classification
new_item = { new_item = {
@ -391,12 +499,32 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
kwargs=dict( kwargs=dict(
gather_output=True, gather_output=True,
fp8_communication=self.shard_config.fp8_communication, fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
), ),
) )
] ]
) )
} }
policy.update(new_item) policy.update(new_item)
# TP not enabled: replace the score layer with LinearWithGradAccum
elif use_zbv:
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score",
target_module=LinearWithGradAccum,
kwargs=dict(
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
}
policy.update(new_item)
# to be confirmed # to be confirmed
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
# set None as default # set None as default
@ -411,7 +539,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True): if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.score) held_layers.append(self.model.score)
return held_layers return held_layers

View File

@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm, FusedRMSNorm,
Linear1D_Col, Linear1D_Col,
Linear1D_Row, Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding, PaddingEmbedding,
PaddingLMHead, PaddingLMHead,
VocabParallelEmbedding1D, VocabParallelEmbedding1D,
@ -62,6 +63,8 @@ class MistralPolicy(Policy):
if self.tie_weight: if self.tie_weight:
embedding_cls = PaddingEmbedding embedding_cls = PaddingEmbedding
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn( warnings.warn(
@ -90,6 +93,7 @@ class MistralPolicy(Policy):
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={ kwargs={
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -97,6 +101,7 @@ class MistralPolicy(Policy):
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={ kwargs={
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -104,6 +109,7 @@ class MistralPolicy(Policy):
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={ kwargs={
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -111,6 +117,7 @@ class MistralPolicy(Policy):
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs={ kwargs={
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -118,6 +125,7 @@ class MistralPolicy(Policy):
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={ kwargs={
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -125,6 +133,7 @@ class MistralPolicy(Policy):
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={ kwargs={
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -132,6 +141,68 @@ class MistralPolicy(Policy):
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs={ kwargs={
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
elif use_zbv:
policy[MistralDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
), ),
], ],

View File

@ -7,9 +7,18 @@ from torch import Tensor
from torch.nn import Module from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col from colossalai.shardformer.layer import (
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D FusedRMSNorm,
from colossalai.shardformer.layer.linear import Linear1D_Row Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding,
VocabParallelEmbedding1D,
)
# from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
# from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
# from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.mixtral import ( from colossalai.shardformer.modeling.mixtral import (
EPMixtralSparseMoeBlock, EPMixtralSparseMoeBlock,
MixtralPipelineForwards, MixtralPipelineForwards,
@ -52,6 +61,7 @@ class MixtralPolicy(Policy):
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size tp_size = self.shard_config.tensor_parallel_size
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# modified for both SP and TP # modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads num_q_heads = self.model.config.num_attention_heads
@ -124,31 +134,92 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.q_proj", suffix="self_attn.q_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.k_proj", suffix="self_attn.k_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.v_proj", suffix="self_attn.v_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.o_proj", suffix="self_attn.o_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="block_sparse_moe.gate", suffix="block_sparse_moe.gate",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
], ],
) )
elif use_zbv:
policy[MixtralDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="block_sparse_moe.gate",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
if embedding_cls is not None: if embedding_cls is not None:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
@ -179,6 +250,7 @@ class MixtralPolicy(Policy):
"tp_group": self.shard_config.tensor_parallel_process_group, "tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group, "moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
) )
], ],
@ -258,6 +330,23 @@ class MixtralPolicy(Policy):
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = [] held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
stage_manager.stage_indices = stage_indices
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
# zbv: the first stage runs the last forward chunk, so it holds the final norm
# interleaved: the last stage runs the last forward chunk, so it holds the final norm
held_layers.append(module.norm)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens) held_layers.append(module.embed_tokens)
@ -265,7 +354,6 @@ class MixtralPolicy(Policy):
held_layers.extend(module.layers[start_idx:end_idx]) held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
held_layers.append(module.norm) held_layers.append(module.norm)
return held_layers return held_layers
@ -297,6 +385,7 @@ class MixtralModelPolicy(MixtralPolicy):
class MixtralForCausalLMPolicy(MixtralPolicy): class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self): def module_policy(self):
policy = super().module_policy() policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TODO: assign pg mesh from plugin to all modules # TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm # add a new item for causal lm
@ -306,9 +395,29 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
) )
] ],
)
}
policy.update(new_item)
elif use_zbv:
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=LinearWithGradAccum,
kwargs=dict(
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
],
) )
} }
policy.update(new_item) policy.update(new_item)
@ -327,7 +436,9 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(): if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head) held_layers.append(self.model.lm_head)
return held_layers return held_layers
@ -353,6 +464,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
from transformers import MixtralForSequenceClassification from transformers import MixtralForSequenceClassification
policy = super().module_policy() policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification # add a new item for sequence classification
@ -362,7 +474,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="score", suffix="score",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
) )
] ]
) )

View File

@ -1,4 +1,4 @@
from typing import Any, List, OrderedDict, Tuple from typing import Any, List, OrderedDict
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -78,9 +78,7 @@ def check_state_dict_equal(
v1 = v1.to(v2.dtype) v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2) assert_close_loose(v1, v2)
else: else:
if isinstance(v1, Tuple) and not isinstance(v2, Tuple): assert v1 == v2, f"{v1} not equals to {v2}"
v2 = tuple(v2)
assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):

View File

@ -1,6 +1,5 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json import json
import warnings
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -12,6 +11,26 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
_TYPES_INV = {v: k for k, v in _TYPES.items()} _TYPES_INV = {v: k for k, v in _TYPES.items()}
import io
from torch.distributed.distributed_c10d import _pickler, _unpickler
def _object_to_tensor(obj, device):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will cause 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage).to(device)
return byte_tensor
def _tensor_to_object(tensor, tensor_size):
tensor = tensor.cpu()
buf = tensor.numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()
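These two helpers follow torch.distributed's object serialization: a picklable Python value is dumped into a uint8 tensor so it can be stored next to real tensors in a safetensors file, and later recovered from the raw bytes. A small round-trip sketch using the helpers defined above (in-memory only; the example values are illustrative):

param_groups = [{"lr": 1e-3, "weight_decay": 0.0}]
t = _object_to_tensor(param_groups, device="cpu")              # uint8 tensor of pickled bytes
restored = _tensor_to_object(t, t.numel() * t.element_size())  # bytes -> original object
assert restored == param_groups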
@dataclass @dataclass
@ -28,49 +47,68 @@ class PreparedData:
offset: int offset: int
def flatten_dict(nested_dict, parent_key="", separator="^"): def _cast_to_tensor(obj):
""" if isinstance(obj, torch.Tensor):
Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator. return obj
return _object_to_tensor(obj, "cpu")
nested_dict: The input nested dictionary.
parent_key: The parent key currently being processed. def _cast_to_object(tensor: torch.Tensor):
separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary." return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
"""
items = []
for k, v in nested_dict.items(): def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
new_key = f"{parent_key}{separator}{k}" if parent_key else str(k) flat_dict = {}
if isinstance(v, dict): non_tensor_keys = []
items.extend(flatten_dict(v, new_key, separator).items()) if "state" in state_dict:
# 3-level dict
states = state_dict["state"]
else: else:
v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v # 2-level dict, usually for optimizer state dict shard
items.append((new_key, v)) states = state_dict
return dict(items) for idx, d in states.items():
for k, v in d.items():
nested_key = f"state{seperator}{idx}{seperator}{k}"
if not isinstance(v, torch.Tensor):
non_tensor_keys.append(nested_key)
flat_dict[nested_key] = _cast_to_tensor(v)
if "param_groups" in state_dict:
flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
non_tensor_keys.append("param_groups")
if len(non_tensor_keys) > 0:
metadata = {"non_tensor_keys": non_tensor_keys}
else:
metadata = None
return flat_dict, metadata
def unflatten_dict(flattened_dict, separator="^"): def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
""" state_dict = {}
Restore a flattened dictionary back to a multi-level nested dictionary. if metadata is not None:
non_tensor_keys = json.loads(metadata["non_tensor_keys"])
else:
non_tensor_keys = []
flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
if "param_groups" in flat_dict:
# 3-level dict
state_dict["param_groups"] = flat_dict.pop("param_groups")
state_dict["state"] = {}
states = state_dict["state"]
else:
# 2-level dict, usually for optimizer state dict shard
states = state_dict
flattened_dict: The flattened dictionary. for k, v in flat_dict.items():
separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary. parts = k.split(seperator)
""" assert len(parts) == 3 and parts[0] == "state"
nested_dict = {} idx = int(parts[1])
for key, value in flattened_dict.items(): key = parts[2]
keys = key.split(separator) if idx not in states:
try: states[idx] = {}
keys[0] = int(keys[0]) states[idx][key] = v
except ValueError:
warnings.warn(f"{key[0]} can't convert to integer")
d = nested_dict
for part in keys[:-1]:
if part not in d:
d[part] = {}
d = d[part]
assert isinstance(value, torch.Tensor)
d[keys[-1]] = value
return nested_dict return state_dict
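Together, _flatten_optim_state_dict turns a nested optimizer state dict into flat "state{sep}{idx}{sep}{key}" entries, pickling non-tensor values and listing their keys in the metadata, while _unflatten_optim_state_dict inverts the transformation on load. A hedged in-memory round trip that skips file I/O and hands the metadata over in the JSON-string form load_flat would see:

import json
import torch

# assumes _flatten_optim_state_dict / _unflatten_optim_state_dict as defined above
optim_sd = {
    "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.zeros(4)}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}
flat, meta = _flatten_optim_state_dict(optim_sd)
# flat keys: "state.0.step", "state.0.exp_avg", "param_groups" (pickled into a uint8 tensor)
# meta: {"non_tensor_keys": ["param_groups"]}
restored = _unflatten_optim_state_dict(
    flat, metadata={"non_tensor_keys": json.dumps(meta["non_tensor_keys"])}
)
assert restored["param_groups"] == optim_sd["param_groups"]
assert torch.equal(restored["state"][0]["exp_avg"], optim_sd["state"][0]["exp_avg"])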
def prepare( def prepare(
@ -124,10 +162,8 @@ def save(
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset) f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
def save_nested( def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None flatten_data, metadata = _flatten_optim_state_dict(state_dict)
) -> None:
flatten_data = flatten_dict(state_dict)
save(f_writer, flatten_data, metadata) save(f_writer, flatten_data, metadata)
@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
with safe_open(checkpoint_path, framework="pt") as f: with safe_open(checkpoint_path, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
state_dict_load = load_file(checkpoint_path) state_dict_load = load_file(checkpoint_path)
state_dict = unflatten_dict(state_dict_load) state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
if metadata is None:
return state_dict return state_dict
metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
combined_state_dict = {"state": state_dict}
combined_state_dict.update(metadata)
return combined_state_dict

View File

@ -351,7 +351,7 @@ class GeminiDDP(ModelWrapper):
loss.backward() loss.backward()
self._post_backward() self._post_backward()
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shouldn't be called in Gemini.") raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shouldn't be called in Gemini.")
@staticmethod @staticmethod

View File

@ -300,12 +300,14 @@ class GeminiOptimizer(OptimizerWrapper):
loss = self.mix_precision_mixin.pre_backward(loss) loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss) self.module.backward(loss)
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): def backward_by_grad(
self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
):
# This function is called on every pipeline stage except the last # This function is called on every pipeline stage except the last
# It receives the scaled grad from the previous rank # It receives the scaled grad from the previous rank
# No need to scale the grad again # No need to scale the grad again
# Need to unscale when optimizing # Need to unscale when optimizing
grad = self.mix_precision_mixin.pre_backward_by_grad(grad) grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
self.module.backward_by_grad(tensor, grad) self.module.backward_by_grad(tensor, grad)
def _maybe_move_fp32_params(self): def _maybe_move_fp32_params(self):

View File

@ -448,7 +448,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# torch.optim.Optimizer methods # torch.optim.Optimizer methods
################################ ################################
def backward(self, loss, retain_graph=False): def backward(self, loss, inputs=None, retain_graph=False):
assert not ( assert not (
self._partition_grads and not self.require_grad_sync self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible" ), "ZeRO2(partition_grads) and no_sync are not compatible"
@ -458,7 +458,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ctx = nullcontext() if self._backward_context is None else self._backward_context() ctx = nullcontext() if self._backward_context is None else self._backward_context()
with ctx: with ctx:
loss.backward(retain_graph=retain_graph) loss.backward(inputs=inputs, retain_graph=retain_graph)
if not self.require_grad_sync: if not self.require_grad_sync:
return return
@ -469,14 +469,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self._overlap_communication: if self._overlap_communication:
get_accelerator().synchronize() get_accelerator().synchronize()
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
assert not ( assert not (
self._partition_grads and not self.require_grad_sync self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None: if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad) torch.autograd.backward(
tensor,
grad,
inputs=inputs,
retain_graph=retain_graph,
)
if not self.require_grad_sync: if not self.require_grad_sync:
return return
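Passing inputs= to torch.autograd.backward restricts gradient accumulation to the listed tensors, which is what lets a zero-bubble schedule compute activation gradients first and defer the weight gradients to a later W pass. A standalone sketch of that split, independent of ColossalAI (all names are illustrative):

import torch

w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4, requires_grad=True)
y = (x @ w).sum()

# "B" pass: only the input gradient, keep the graph alive for the deferred pass
torch.autograd.backward(y, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# "W" pass: compute the weight gradient from the retained graph
torch.autograd.backward(y, inputs=[w])
assert w.grad is not None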

View File

@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -39,6 +40,7 @@ MODEL_CONFIGS = {
), ),
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8), "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
"7b": LlamaConfig(max_position_embeddings=4096), "7b": LlamaConfig(max_position_embeddings=4096),
# "7b": LlamaConfig(num_hidden_layers=4, max_position_embeddings=4096),
"13b": LlamaConfig( "13b": LlamaConfig(
hidden_size=5120, hidden_size=5120,
intermediate_size=13824, intermediate_size=13824,
@ -91,7 +93,7 @@ def main():
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument( parser.add_argument(
@ -106,6 +108,7 @@ def main():
parser.add_argument("--no_cache", action="store_true") parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
parser.add_argument("--overlap_allgather", action="store_true") parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument( parser.add_argument(
"--sp_mode", "--sp_mode",
@ -137,6 +140,11 @@ def main():
# ============================== # ==============================
# Initialize Booster # Initialize Booster
# ============================== # ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
use_empty_init = True use_empty_init = True
if args.plugin == "gemini": if args.plugin == "gemini":
plugin = GeminiPlugin( plugin = GeminiPlugin(
@ -210,6 +218,24 @@ def main():
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
elif args.plugin == "3d": elif args.plugin == "3d":
if args.pp_style == "zbv":
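# rough per-layer activation-memory estimates (heuristic) fed to the V-schedule solver:
# f/b/w are the forward, backward-input and backward-weight passes; mem_b is set so mem_f + mem_b + mem_w == 0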
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f * 1.5,
b_mem=mem_b * 1.5,
w_mem=mem_w * 1.5,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = HybridParallelPlugin( plugin = HybridParallelPlugin(
tp_size=args.tp, tp_size=args.tp,
pp_size=args.pp, pp_size=args.pp,
@ -227,6 +253,7 @@ def main():
overlap_allgather=args.overlap_allgather, overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs, **hybrid_kwargs,
) )
elif args.plugin == "3d_cpu": elif args.plugin == "3d_cpu":
@ -242,7 +269,7 @@ def main():
microbatch_size=args.mbs, microbatch_size=args.mbs,
initial_scale=2**8, initial_scale=2**8,
precision="bf16", precision="bf16",
overlap_p2p=args.overlap, overlap_p2p=args.overlap_p2p,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
@ -260,6 +287,7 @@ def main():
config = MODEL_CONFIGS[args.config] config = MODEL_CONFIGS[args.config]
else: else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
dataset = RandomDataset( dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@ -319,7 +347,7 @@ def main():
args.profile, args.profile,
args.ignore_steps, args.ignore_steps,
1, # avoid creating massive log files 1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys, nsys=args.nsys,
) as prof: ) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
@ -334,7 +362,11 @@ def main():
return_loss=True, return_loss=True,
) )
loss = outputs["loss"] loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1: if args.pp_style == "zbv":
if coordinator.is_master():
print(f"Step {step} loss: {loss}")
else:
if coordinator.is_last_process():
print(f"Step {step} loss: {loss}") print(f"Step {step} loss: {loss}")
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()

View File

@ -11,6 +11,7 @@ from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context from performance_evaluator import PerformanceEvaluator, get_profile_context
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoConfig
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai import colossalai
@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -85,7 +87,7 @@ def main():
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument( parser.add_argument(
@ -129,7 +131,29 @@ def main():
# ============================== # ==============================
# Initialize Booster # Initialize Booster
# ============================== # ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
if args.plugin == "3d": if args.plugin == "3d":
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = MoeHybridParallelPlugin( plugin = MoeHybridParallelPlugin(
ep_size=args.ep, ep_size=args.ep,
tp_size=args.tp, tp_size=args.tp,
@ -143,11 +167,13 @@ def main():
enable_fused_normalization=torch.cuda.is_available(), enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers, enable_flash_attention=args.xformers,
microbatch_size=args.mbs, microbatch_size=args.mbs,
num_microbatches=args.batch_size // args.mbs,
precision="bf16", precision="bf16",
enable_metadata_cache=not args.no_cache, enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather, overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs, **hybrid_kwargs,
) )
else: else:
@ -183,8 +209,10 @@ def main():
with init_ctx: with init_ctx:
model = MixtralForCausalLM(config=config).to(torch.bfloat16) model = MixtralForCausalLM(config=config).to(torch.bfloat16)
# if args.grad_checkpoint:
# model.gradient_checkpointing_enable()
if args.grad_checkpoint: if args.grad_checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model_numel = get_model_numel(model) model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
@ -229,6 +257,10 @@ def main():
return_loss=True, return_loss=True,
) )
loss = outputs["loss"] loss = outputs["loss"]
if args.pp_style == "zbv":
if dist.get_rank() == 0:
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1: if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}") print(f"Step {step} loss: {loss}")
optimizer.step() optimizer.step()

View File

@ -21,11 +21,16 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float: def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1: if world_size == 1:
return x return x
# BUG: RuntimeError: Invalid scalar type when using dist.all_reduce(tensor, group=gloo_group)
# # Use CPU tensor to avoid OOM/weird NCCL error
# gloo_group = dist.new_group(backend="gloo")
# tensor = torch.tensor([x], device="cpu")
# dist.all_reduce(tensor, group=gloo_group)
# tensor = tensor / world_size
# return tensor.item()
# Use CPU tensor to avoid OOM/weird NCCl error tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
gloo_group = dist.new_group(backend="gloo") dist.all_reduce(tensor)
tensor = torch.tensor([x], device="cpu")
dist.all_reduce(tensor, group=gloo_group)
tensor = tensor / world_size tensor = tensor / world_size
return tensor.item() return tensor.item()

View File

@ -1,9 +1,9 @@
import tempfile import tempfile
from copy import deepcopy
import torch import torch
from safetensors.torch import load_file
from colossalai.utils.safetensors import load_flat, save_nested from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
try: try:
from tensornvme.async_file_io import AsyncFileWriter from tensornvme.async_file_io import AsyncFileWriter
@ -11,17 +11,29 @@ except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
from colossalai.testing import check_state_dict_equal from colossalai.testing import check_state_dict_equal
from colossalai.utils import get_current_device
def test_save_load(): def test_save_load():
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
optimizer_state_dict = { optimizer_state_dict = {
0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, "state": {
1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, 0: {
2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, "step": torch.tensor(1.0),
} "exp_avg": torch.rand((1024, 1024)),
# group_dict = {"param_groups": [0, 1, 2]} "exp_avg_sq": torch.rand((1024, 1024)),
group_dict = { },
1: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
2: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
},
"param_groups": [ "param_groups": [
{ {
"lr": 0.001, "lr": 0.001,
@ -94,22 +106,26 @@ def test_save_load():
61, 61,
], ],
} }
] ],
} }
metadata = deepcopy(group_dict)
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors" optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread") f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict)
save_nested(f_writer, optimizer_state_dict, metadata)
f_writer.sync_before_step() f_writer.sync_before_step()
f_writer.synchronize() f_writer.synchronize()
f_writer.fp.close() f_writer.fp.close()
load_state_dict = load_flat(optimizer_saved_path) load_state_dict = load_flat(optimizer_saved_path)
state_dict = load_state_dict["state"] check_state_dict_equal(load_state_dict, optimizer_state_dict)
group = {"param_groups": load_state_dict["param_groups"]}
check_state_dict_equal(optimizer_state_dict, state_dict) optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
check_state_dict_equal(group_dict, group) f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict["state"])
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
model_state_dict = { model_state_dict = {
"module.weight0": torch.rand((1024, 1024)), "module.weight0": torch.rand((1024, 1024)),
@ -118,10 +134,20 @@ def test_save_load():
} }
model_saved_path = f"{tempdir}/save_model.safetensors" model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread") f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, model_state_dict) save(f_writer, model_state_dict)
f_writer.sync_before_step() f_writer.sync_before_step()
f_writer.synchronize() f_writer.synchronize()
f_writer.fp.close() f_writer.fp.close()
load_state_dict = load_file(model_saved_path)
load_state_dict = load_flat(model_saved_path) check_state_dict_equal(model_state_dict, load_state_dict)
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
load_state_dict = load_file(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict) check_state_dict_equal(model_state_dict, load_state_dict)

View File

@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
self.is_interleave = False self.is_interleave = False
self.num_layers_per_stage = None self.num_layers_per_stage = None
self.num_model_chunks = 1 self.num_model_chunks = 1
self.use_zbv = False
@property @property
def num_stages(self): def num_stages(self):

View File

@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
self.is_interleave = False self.is_interleave = False
self.num_layers_per_stage = None self.num_layers_per_stage = None
self.num_model_chunks = 1 self.num_model_chunks = 1
self.use_zbv = False
@property @property
def num_stages(self): def num_stages(self):

File diff suppressed because it is too large

View File

@ -8,7 +8,8 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad) assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
)
assert linear_base.weight.shape == torch.Size([128, 32])
assert linear_base.bias.shape == torch.Size([128])
assert linear_copy.weight is linear_base.weight
assert linear_copy.bias is linear_base.bias
linear.load_state_dict(linear_base.state_dict())
linear_base.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_base(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
assert_close(linear.weight.grad, linear_base.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
)
assert linear_base.weight.shape == torch.Size([128, 32])
assert linear_base.bias.shape == torch.Size([128])
assert linear_copy.weight is linear_base.weight
assert linear_copy.bias is linear_base.bias
linear.load_state_dict(linear_base.state_dict())
linear_base.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_base(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
# Weight grad is None before we do WeightGradStore pop
assert linear_base.weight.grad is None
# after WeightGradStore.pop (dW computation complete), the weight grad should match the reference
WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
WeightGradStore.pop(chunk=0)
assert_close(linear.weight.grad, linear_base.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
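The test above captures the WeightGradStore contract behind zbv: with use_zbv=True the layer's backward records the weight-gradient computation instead of running it, so .grad stays None until the recorded work is flushed into a chunk queue and popped. A condensed usage sketch under the same assumptions (same classes and imports as this test file; requires a CUDA device):

layer = LinearWithGradAccum.from_native_module(
    nn.Linear(32, 128).cuda(), parallel_input=False, seq_parallel_mode=None, use_zbv=True
)
out = layer(torch.rand(2, 4, 32).cuda())
out.sum().backward()              # dX is computed now, dW is only recorded
assert layer.weight.grad is None
WeightGradStore.flush(chunk=0)    # move the recorded dW work into chunk 0's queue
WeightGradStore.pop(chunk=0)      # execute the deferred dW computation
assert layer.weight.grad is not None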
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext() ctx = LazyInitContext() if lazy_init else nullcontext()
@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
check_linear_1d_col(lazy_init, seq_parallel_mode, overlap) check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
check_linear_1d_row(lazy_init, seq_parallel_mode) check_linear_1d_row(lazy_init, seq_parallel_mode)
check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap) check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
def check_dist_linear(rank, world_size, port): def check_dist_linear(rank, world_size, port):

View File

@ -310,10 +310,18 @@ def check_output_hidden_state(
): ):
org_hidden_state = org_output.last_hidden_state org_hidden_state = org_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): if stage_manager:
if stage_manager.use_zbv:
if stage_manager.is_first_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else: else:
sharded_hidden_state = sharded_output.last_hidden_state sharded_hidden_state = sharded_output.last_hidden_state
elif stage_manager.is_last_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else:
sharded_hidden_state = sharded_output.last_hidden_state
else:
sharded_hidden_state = sharded_output.last_hidden_state
# Check if the output sequence is gathered before cross entropy # Check if the output sequence is gathered before cross entropy
if shard_config is not None: if shard_config is not None:
@ -388,7 +396,6 @@ def get_grad_tensors_for_check(
pass pass
if verbose and dist.get_rank() == 0: if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}") print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
grad_to_check[suffix] = { grad_to_check[suffix] = {
"org_grad": org_grad.float(), "org_grad": org_grad.float(),
"shard_grad": shard_grad.float(), "shard_grad": shard_grad.float(),

View File

@ -7,6 +7,7 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.tensor.d_tensor.api import clear_layout_converter
@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
) )
if enable_gradient_checkpointing: if enable_gradient_checkpointing:
# org_model.gradient_checkpointing_enable() # org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
@ -112,12 +113,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
sharded_optimizer.step() sharded_optimizer.step()
# check last hidden state & loss # check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): check_flag = False
if (
(stage_manager is None)
or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True))
or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True))
):
check_flag = True
if check_flag:
if test_config["precision"] == "fp32": if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3 atol, rtol = 1e-5, 1e-3
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "LlamaModel": if org_model.__class__.__name__ == "LlamaModel":
check_output_hidden_state( check_output_hidden_state(
org_output, org_output,
@ -274,6 +281,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
) )
def run_llama_test(test_config): def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
if test_config.get("pp_style", None) == "zbv":
mem_f = 34 * 32 + 5 * 4 * 16
mem_w = -32 * 32
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=test_config["pp_size"],
n_micro=test_config["num_microbatches"],
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
test_config["scheduler_nodes"] = scheduler_nodes
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
continue continue