mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 14:10:29 +00:00
Merge branch 'main' into ckpt_api
This commit is contained in:
commit
fa0318dba5
16
README.md
16
README.md
@ -25,24 +25,20 @@
|
||||
|
||||
</div>
|
||||
|
||||
## GPU Cloud HPC-AI.COM Coming!!
|
||||
## Get Started with Colossal-AI Without Setup
|
||||
|
||||
For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
|
||||
Plus, when you refer a friend, you’ll receive 20% cashback or compute credits equal to 100% of their top-up!
|
||||
Access high-end, on-demand compute for your research instantly—no setup needed.
|
||||
|
||||
Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
|
||||
Don’t miss this incredible opportunity to accelerate your AI projects!
|
||||
Sign up now and get $10 in credits!
|
||||
|
||||
Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!
|
||||
|
||||
Special Bonuses:
|
||||
Limited Academic Bonuses:
|
||||
|
||||
* Top up $1,000 and receive 300 credits
|
||||
* Top up $500 and receive 100 credits
|
||||
|
||||
<div align="center">
|
||||
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/HPCAICOM241010.jpg" width="700" />
|
||||
<a href="https://hpc-ai.com/?utm_source=github&utm_medium=social&utm_campaign=promotion-colossalai">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2.gif" width="850" />
|
||||
</a>
|
||||
</div>
|
||||
|
||||
|
@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC):
|
||||
dtype: torch.dtype
|
||||
|
||||
@abstractmethod
|
||||
def pre_backward(self, loss: Tensor) -> Tensor:
|
||||
def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
|
||||
"""Called before backward.
|
||||
|
||||
Args:
|
||||
|
@ -86,13 +86,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
|
||||
group["params"] = master_params
|
||||
self._current_grad_norm: Optional[float] = None
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||
loss = self.mixed_precision.pre_backward(loss)
|
||||
loss.backward(*args, **kwargs)
|
||||
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
|
||||
tensor.backward(grad)
|
||||
torch.autograd.backward(
|
||||
tensors=tensor,
|
||||
grad_tensors=grad,
|
||||
inputs=inputs,
|
||||
retain_graph=retain_graph,
|
||||
)
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
for p in self.working_to_master_map.keys():
|
||||
|
@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
|
||||
growth_interval=growth_interval,
|
||||
)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs) -> None:
|
||||
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
|
||||
scaled_loss = self.scale_loss(loss)
|
||||
scaled_loss.backward(*args, **kwargs)
|
||||
scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
def step(self, *args, **kwargs) -> Optional[float]:
|
||||
out = self.scaler.step(self.optim, *args, **kwargs)
|
||||
|
@ -28,7 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.interface.optimizer import DistributedOptim
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
@ -296,7 +296,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
self._current_grad_norm: Optional[float] = None
|
||||
super().__init__(optim)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||
r"""
|
||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
@ -315,7 +315,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
with self.model._hook_context():
|
||||
super().backward(loss, *args, **kwargs)
|
||||
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
@ -324,7 +324,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
# If gradient synchronization is is not required, return.
|
||||
return
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||
"""
|
||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
@ -341,7 +341,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
@ -525,7 +525,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||
max_norm=max_norm,
|
||||
)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||
r"""
|
||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
@ -543,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||
"""
|
||||
# Call the superclass backward method to compute gradients.
|
||||
with self.model._hook_context():
|
||||
super().backward(loss, *args, **kwargs)
|
||||
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
@ -552,7 +552,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||
# If gradient synchronization is is not required, return.
|
||||
return
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||
"""
|
||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
@ -568,7 +568,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||
None
|
||||
"""
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
@ -785,7 +785,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
else:
|
||||
return
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
def backward(self, loss, inputs=None, retain_graph=False):
|
||||
"""
|
||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
@ -801,7 +801,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
None
|
||||
"""
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward(loss, retain_graph)
|
||||
super().backward(loss, inputs=inputs, retain_graph=retain_graph)
|
||||
|
||||
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
@ -810,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
# If gradient synchronization is is not required, return.
|
||||
return
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
|
||||
"""
|
||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
@ -826,7 +826,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
None
|
||||
"""
|
||||
# Call the superclass backward_by_grad method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
|
||||
|
||||
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
@ -1030,6 +1030,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
custom_policy: Policy = None,
|
||||
pp_style: str = "1f1b",
|
||||
num_model_chunks: int = 1,
|
||||
scheduler_nodes: List = None,
|
||||
num_layers_per_stage: Optional[List[int]] = None,
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
@ -1048,6 +1049,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
dist.get_world_size() % (tp_size * pp_size) == 0
|
||||
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||
|
||||
assert (
|
||||
not pp_style == "zbv" or scheduler_nodes is not None
|
||||
), f"scheduler_nodes must not be None when using zero bubble pipeline."
|
||||
if enable_sequence_parallelism:
|
||||
self.sequence_parallelism_mode = (
|
||||
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
|
||||
@ -1109,29 +1113,39 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
|
||||
|
||||
self.stage_manager = None
|
||||
self.schedule = None
|
||||
self.scheduler = None
|
||||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
||||
assert (
|
||||
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
|
||||
), "num_model_chunks must be 1 when using 1f1b"
|
||||
assert (
|
||||
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
|
||||
), "num_model_chunks must be 2 when using zero bubble pipeline"
|
||||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
assert (
|
||||
self.zero_stage <= 1
|
||||
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
if pp_style == "zbv":
|
||||
self.logger.warning(
|
||||
"""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
|
||||
)
|
||||
self.stage_manager = PipelineStageManager(
|
||||
self.pg_mesh,
|
||||
pipeline_axis=self.pp_axis,
|
||||
enable_interleave=pp_style == "interleaved",
|
||||
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
|
||||
use_zbv=(pp_style == "zbv"),
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_layers_per_stage=num_layers_per_stage,
|
||||
)
|
||||
|
||||
if pp_style == "interleaved":
|
||||
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
||||
self.schedule = InterleavedSchedule(
|
||||
self.scheduler = InterleavedSchedule(
|
||||
stage_manager=self.stage_manager,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
@ -1141,13 +1155,21 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
self.scheduler = OneForwardOneBackwardSchedule(
|
||||
stage_manager=self.stage_manager,
|
||||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
elif pp_style == "zbv":
|
||||
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||
stage_manager=self.stage_manager,
|
||||
schedule=scheduler_nodes,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if sequence_parallelism_mode == "ring_attn":
|
||||
@ -1263,7 +1285,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
|
||||
# Replace with distributed implementation if exists
|
||||
optimizer = cast_to_distributed(optimizer)
|
||||
|
||||
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
|
||||
self.logger.warning(
|
||||
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
|
||||
@ -1278,6 +1299,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
self.dp_size == 1 and self.pp_size == 1
|
||||
)
|
||||
# sync gradients across DP * SP ranks
|
||||
# sync gradients across DP * SP ranks
|
||||
# Apply Hybrid ZeRO across DP * SP ranks
|
||||
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
|
||||
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
|
||||
@ -1380,7 +1402,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
with ctx, model._hook_context():
|
||||
outputs = self.schedule.forward_backward_step(
|
||||
outputs = self.scheduler.forward_backward_step(
|
||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
||||
|
@ -141,8 +141,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
|
||||
from colossalai.utils.safetensors import save_nested
|
||||
|
||||
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
|
||||
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
|
||||
f_writer = AsyncFileWriter(
|
||||
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
|
||||
)
|
||||
save_nested(f_writer, state_dict)
|
||||
self.async_writers.append(f_writer)
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
from colossalai.utils.safetensors import save_nested
|
||||
|
||||
f_writer = AsyncFileWriter(
|
||||
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
|
||||
fp=open(checkpoint_file_path, "wb", buffering=0),
|
||||
n_entries=self.N_WRITE_ENTRIES,
|
||||
backend="pthread",
|
||||
)
|
||||
save_nested(f_writer, shard)
|
||||
self.async_writers.append(f_writer)
|
||||
|
@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import cast_to_distributed
|
||||
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
|
||||
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
|
||||
@ -212,6 +213,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
custom_policy: Policy = None,
|
||||
pp_style: str = "1f1b",
|
||||
num_model_chunks: int = 1,
|
||||
scheduler_nodes: List = None,
|
||||
num_layers_per_stage: Optional[List[int]] = None,
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
@ -285,12 +287,17 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
|
||||
|
||||
self.stage_manager = None
|
||||
self.schedule = None
|
||||
self.scheduler = None
|
||||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
||||
assert (
|
||||
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
|
||||
), "num_model_chunks must be 1 when using 1f1b"
|
||||
assert (
|
||||
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
|
||||
), "num_model_chunks must be 2 when using zero bubble pipeline"
|
||||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
@ -300,14 +307,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
self.stage_manager = PipelineStageManager(
|
||||
self.pg_mesh,
|
||||
pipeline_axis=self.pp_axis,
|
||||
enable_interleave=pp_style == "interleaved",
|
||||
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_layers_per_stage=num_layers_per_stage,
|
||||
use_zbv=(pp_style == "zbv"),
|
||||
)
|
||||
|
||||
if pp_style == "interleaved":
|
||||
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
||||
self.schedule = InterleavedSchedule(
|
||||
self.scheduler = InterleavedSchedule(
|
||||
stage_manager=self.stage_manager,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
@ -316,12 +324,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
overlap_p2p=overlap_p2p,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
self.scheduler = OneForwardOneBackwardSchedule(
|
||||
stage_manager=self.stage_manager,
|
||||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
)
|
||||
elif pp_style == "zbv":
|
||||
assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV"
|
||||
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||
schedule=scheduler_nodes,
|
||||
stage_manager=self.stage_manager,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
overlap_p2p=overlap_p2p,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -61,7 +61,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
self.async_writers.append(writer)
|
||||
|
@ -702,7 +702,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
complete_state_dict.update(_state_dict)
|
||||
if use_async:
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
writer = AsyncFileWriter(
|
||||
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
|
||||
)
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
|
||||
self.async_writers.append(writer)
|
||||
|
@ -311,7 +311,7 @@ def async_save_state_dict_shards(
|
||||
index_file.append_weight_map(key, shard_file)
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
|
||||
writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
|
||||
writers.append(writer)
|
||||
|
||||
if pinned_state_dict is not None:
|
||||
|
@ -49,14 +49,31 @@ class OptimizerWrapper:
|
||||
"""
|
||||
self.optim.zero_grad(*args, **kwargs)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||
"""
|
||||
Performs a backward pass on the loss.
|
||||
"""
|
||||
loss.backward(*args, **kwargs)
|
||||
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
torch.autograd.backward(tensor, grad)
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||
"""
|
||||
Performs a backward pass for dx or dw,
|
||||
for dx, we only calculate dx = w*dy here
|
||||
for dw, we only calculate dw = x*dy here
|
||||
|
||||
Args:
|
||||
tensor (Tensor): y or loss of current chunk;
|
||||
grad_tensors (Tensor): dy of current chunk;
|
||||
input_obj (Tensor): for dx, input_obj is x of current chunk;
|
||||
for dw, input_obj is w of current chunk;
|
||||
retain_graph (bool): default to be True, we retain graph in backward_b
|
||||
"""
|
||||
torch.autograd.backward(
|
||||
tensors=tensor,
|
||||
grad_tensors=grad,
|
||||
inputs=inputs,
|
||||
retain_graph=retain_graph,
|
||||
)
|
||||
|
||||
def state_dict(self):
|
||||
"""
|
||||
|
@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
|
||||
# if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
|
||||
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
super().load_state_dict(state_dict)
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
state = self.state[p]
|
||||
if "step" in state and isinstance(state["step"], torch.Tensor):
|
||||
state["step"] = int(state["step"].item())
|
||||
|
||||
def torch_adam_update(
|
||||
self,
|
||||
data,
|
||||
|
@ -1,11 +1,12 @@
|
||||
from .p2p import PipelineP2PCommunication
|
||||
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
|
||||
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
|
||||
from .stage_manager import PipelineStageManager
|
||||
|
||||
__all__ = [
|
||||
"PipelineSchedule",
|
||||
"OneForwardOneBackwardSchedule",
|
||||
"InterleavedSchedule",
|
||||
"ZeroBubbleVPipeScheduler",
|
||||
"PipelineP2PCommunication",
|
||||
"PipelineStageManager",
|
||||
]
|
||||
|
@ -432,7 +432,6 @@ def _communicate(
|
||||
overlap_p2p=overlap_p2p,
|
||||
send_first=send_first if send_first != None else True,
|
||||
)
|
||||
|
||||
if metadata_recv is not None:
|
||||
assert isinstance(metadata_recv, P2PMetadata)
|
||||
tree_spec = metadata_recv.tree_spec
|
||||
|
@ -1,9 +1,11 @@
|
||||
from .base import PipelineSchedule
|
||||
from .interleaved_pp import InterleavedSchedule
|
||||
from .one_f_one_b import OneForwardOneBackwardSchedule
|
||||
from .zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
|
||||
__all__ = [
|
||||
"PipelineSchedule",
|
||||
"OneForwardOneBackwardSchedule",
|
||||
"InterleavedSchedule",
|
||||
"ZeroBubbleVPipeScheduler",
|
||||
]
|
||||
|
@ -137,6 +137,16 @@ def retain_grad(x: Any) -> None:
|
||||
x.retain_grad()
|
||||
|
||||
|
||||
def require_grad(x: Any) -> None:
|
||||
"""Call require_grad on a tensor.
|
||||
|
||||
Args:
|
||||
x (Any): Object to be called.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor) and not x.requires_grad:
|
||||
x.requires_grad_()
|
||||
|
||||
|
||||
def detach(x: Any) -> Any:
|
||||
"""Call detach() on a tensor.
|
||||
|
||||
@ -151,6 +161,34 @@ def detach(x: Any) -> Any:
|
||||
return x
|
||||
|
||||
|
||||
def clone(x: Any) -> Any:
|
||||
"""Call clone() on a tensor.
|
||||
|
||||
Args:
|
||||
x (Any): Object to be called.
|
||||
|
||||
Returns:
|
||||
Any: The cloned object.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.clone()
|
||||
return x
|
||||
|
||||
|
||||
def release_tensor_data(x: Any) -> Any:
|
||||
"""Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
|
||||
|
||||
Args:
|
||||
x (Any): Object to be called.
|
||||
|
||||
Returns:
|
||||
Any: The deallocate .data object.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.data.untyped_storage().resize_(0)
|
||||
return x
|
||||
|
||||
|
||||
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
|
||||
"""Merge micro batches into a batch.
|
||||
|
||||
|
449
colossalai/pipeline/schedule/v_schedule.py
Normal file
449
colossalai/pipeline/schedule/v_schedule.py
Normal file
@ -0,0 +1,449 @@
|
||||
# Refer from Zero Bubble Pipeline Parallelism.
|
||||
# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
|
||||
# Paper: https://arxiv.org/abs/2401.10241
|
||||
# The following applies to all files unless otherwise noted:
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class ScheduledNode:
|
||||
type: str
|
||||
chunk: int
|
||||
stage: int
|
||||
minibatch: int
|
||||
start_time: int = 0
|
||||
completion_time: int = 0
|
||||
rollback: bool = False
|
||||
|
||||
|
||||
class PipelineGraph(object):
|
||||
"""PipelineGraph"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_stage,
|
||||
n_micro,
|
||||
f_cost,
|
||||
b_cost,
|
||||
w_cost,
|
||||
c_cost,
|
||||
f_mem,
|
||||
b_mem,
|
||||
w_mem,
|
||||
max_mem=None,
|
||||
):
|
||||
self.n_node = 6 * n_stage * n_micro
|
||||
self.n_stage = n_stage
|
||||
self.n_micro = n_micro
|
||||
self.f_cost = f_cost
|
||||
self.b_cost = b_cost
|
||||
self.w_cost = w_cost
|
||||
self.c_cost = c_cost
|
||||
self.f_mem = f_mem
|
||||
self.b_mem = b_mem
|
||||
self.w_mem = w_mem
|
||||
self.fbw_cost = [f_cost, b_cost, w_cost]
|
||||
self.fbw_mem = [f_mem, b_mem, w_mem]
|
||||
self.max_mem = max_mem or f_mem * self.n_stage * 2
|
||||
|
||||
def get_id(self, cat, chunk, stage, micro):
|
||||
return (
|
||||
cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro
|
||||
)
|
||||
|
||||
def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
|
||||
count = []
|
||||
for i in range(self.n_stage):
|
||||
count.append([0] * 6)
|
||||
|
||||
end_time = [-1] * self.n_node
|
||||
cur_time = [0] * self.n_stage
|
||||
mem = [0] * self.n_stage
|
||||
stage_bubble = [0] * self.n_stage
|
||||
pending_w = [deque() for _ in range(self.n_stage)]
|
||||
schedule = [[] for _ in range(self.n_stage)]
|
||||
stage_str = [" " * i for i in range(self.n_stage)]
|
||||
|
||||
if approved_bubble is None:
|
||||
approved_bubble = [-1] * self.n_stage
|
||||
max_approved_bubble = max(approved_bubble)
|
||||
|
||||
def get_max_stage_bubble(stage=-1):
|
||||
max_stage_bubble = 0
|
||||
for bb in stage_bubble:
|
||||
max_stage_bubble = max(max_stage_bubble, bb)
|
||||
if stage >= 0:
|
||||
max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
|
||||
return max_stage_bubble
|
||||
|
||||
def put_w(stage):
|
||||
assert len(pending_w[stage]) > 0
|
||||
_, chunk_, _ = pending_w[stage].popleft()
|
||||
put(2, chunk_, stage)
|
||||
|
||||
def put(cat, chunk, stage, assert_cnt=True):
|
||||
_tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
|
||||
_cnt = count[stage][cat * 2 + chunk]
|
||||
# assert _cnt < self.n_micro
|
||||
if _cnt >= self.n_micro:
|
||||
if not assert_cnt:
|
||||
stage_str[stage] += " "
|
||||
cur_time[stage] = _tmp # TODO
|
||||
return
|
||||
assert False
|
||||
assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
|
||||
stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
|
||||
if cat > 0 or chunk > 0:
|
||||
last_id = cat * 2 + chunk - 1
|
||||
if cat < 2:
|
||||
assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
|
||||
else:
|
||||
assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
|
||||
if chunk == 1 and cat < 2:
|
||||
if stage < self.n_stage - 1:
|
||||
_fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
|
||||
assert end_time[_fa_id] >= 0
|
||||
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
|
||||
if chunk == 0 and cat < 2:
|
||||
if stage > 0:
|
||||
_fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
|
||||
assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
|
||||
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
|
||||
_id = self.get_id(cat, chunk, stage, _cnt)
|
||||
if count[stage][0] > 0:
|
||||
stage_bubble[stage] += _tmp - _no_bubble
|
||||
end_time[_id] = _tmp
|
||||
cur_time[stage] = _tmp
|
||||
mem[stage] += self.fbw_mem[cat]
|
||||
# noinspection PyTypeChecker
|
||||
schedule[stage].append((cat, chunk, _cnt))
|
||||
if cat == 1:
|
||||
pending_w[stage].append((2, chunk, _cnt))
|
||||
count[stage][cat * 2 + chunk] += 1
|
||||
|
||||
for i in range(self.n_stage):
|
||||
put(0, 0, i)
|
||||
for i in range(self.n_stage - 1, -1, -1):
|
||||
if i == self.n_stage - 1:
|
||||
put(0, 1, i)
|
||||
continue
|
||||
tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
|
||||
while (
|
||||
mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem
|
||||
and cur_time[i] + self.fbw_cost[0] <= tmp
|
||||
and count[i][0] < self.n_micro
|
||||
):
|
||||
for j in range(i + 1):
|
||||
put(0, 0, j)
|
||||
put(0, 1, i)
|
||||
iter_chunk_ = 0
|
||||
end_tmp = 0
|
||||
for i in range(self.n_stage):
|
||||
if i == 0:
|
||||
end_tmp = cur_time[0] + self.fbw_cost[1]
|
||||
continue
|
||||
tmp = end_tmp + self.c_cost
|
||||
while (
|
||||
count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1]
|
||||
or count[i][1] <= count[i - 1][1] < self.n_micro
|
||||
):
|
||||
for j in range(self.n_stage - 1, i - 1, -1):
|
||||
if count[j][iter_chunk_] < self.n_micro:
|
||||
put(0, iter_chunk_, j)
|
||||
iter_chunk_ = 1 - iter_chunk_
|
||||
|
||||
for _ in range(2 * self.n_micro):
|
||||
# check mem before putting b
|
||||
for i in range(self.n_stage):
|
||||
while mem[i] + self.fbw_mem[1] > self.max_mem:
|
||||
assert len(pending_w[i]) > 0
|
||||
put_w(i)
|
||||
b0_ranks, b1_ranks = [], []
|
||||
for i in range(self.n_stage):
|
||||
if count[i][3] >= count[i][2]:
|
||||
b0_ranks.append(i)
|
||||
elif i == self.n_stage - 1:
|
||||
b1_ranks.append(i)
|
||||
else:
|
||||
fa_id = self.get_id(1, 1, i + 1, count[i][3])
|
||||
if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
|
||||
b1_ranks.append(i)
|
||||
else:
|
||||
b0_ranks.append(i)
|
||||
b_ranks = []
|
||||
# put b1
|
||||
for i in reversed(b1_ranks):
|
||||
b_ranks.append((i, 1))
|
||||
# put b0
|
||||
for i in b0_ranks:
|
||||
b_ranks.append((i, 0))
|
||||
for i, _chunk_ in b_ranks:
|
||||
fa_id = -1
|
||||
if _chunk_ == 1 and i < self.n_stage - 1:
|
||||
fa_id = self.get_id(1, 1, i + 1, count[i][3])
|
||||
if _chunk_ == 0 and i > 0:
|
||||
fa_id = self.get_id(1, 0, i - 1, count[i][2])
|
||||
while (
|
||||
len(pending_w[i]) > 0
|
||||
and fa_id >= 0
|
||||
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
|
||||
):
|
||||
# fill the bubble
|
||||
put_w(i)
|
||||
if (
|
||||
len(pending_w[i]) > 0
|
||||
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
|
||||
):
|
||||
if _chunk_ == 1:
|
||||
put_w(i)
|
||||
elif fill_b:
|
||||
put_w(i)
|
||||
put(1, _chunk_, i)
|
||||
|
||||
# put f
|
||||
for i in range(self.n_stage):
|
||||
if count[i][1] >= self.n_micro:
|
||||
continue
|
||||
put_item = None
|
||||
if count[i][1] >= count[i][0]:
|
||||
put_item = 0
|
||||
elif i == self.n_stage - 1:
|
||||
put_item = 1
|
||||
else:
|
||||
if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
|
||||
put_item = 1
|
||||
elif count[i][0] < self.n_micro:
|
||||
if i == 0:
|
||||
put_item = 0
|
||||
elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
|
||||
put_item = 0
|
||||
if put_item is None:
|
||||
continue
|
||||
# check mem before putting f
|
||||
while mem[i] + self.fbw_mem[0] > self.max_mem:
|
||||
assert len(pending_w[i]) > 0
|
||||
put_w(i)
|
||||
fa_id = -1
|
||||
if put_item == 0 and i > 0:
|
||||
fa_id = self.get_id(0, 0, i - 1, count[i][0])
|
||||
if put_item == 1 and i < self.n_stage - 1:
|
||||
fa_id = self.get_id(0, 1, i + 1, count[i][1])
|
||||
while (
|
||||
len(pending_w[i]) > 0
|
||||
and fa_id >= 0
|
||||
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
|
||||
):
|
||||
# fill the bubble
|
||||
put_w(i)
|
||||
if (
|
||||
len(pending_w[i]) > 0
|
||||
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
|
||||
):
|
||||
if fill_f:
|
||||
put_w(i)
|
||||
put(0, put_item, i)
|
||||
|
||||
for i in range(self.n_stage):
|
||||
while len(pending_w[i]) > 0:
|
||||
put_w(i)
|
||||
|
||||
max_bubble = get_max_stage_bubble()
|
||||
expected_time = sum(self.fbw_cost) * self.n_micro * 2
|
||||
max_bubble / expected_time
|
||||
if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
|
||||
_schedule, _end_time, _max_bubble = self.try_v_schedule(
|
||||
fill_f=fill_f,
|
||||
fill_b=fill_b,
|
||||
approved_bubble=stage_bubble,
|
||||
)
|
||||
if _max_bubble < max_bubble:
|
||||
return _schedule, _end_time, _max_bubble
|
||||
return schedule, end_time, max_bubble
|
||||
|
||||
def print_details(self, end_time, print_scaling=1):
|
||||
for stage in range(self.n_stage):
|
||||
stage_str = ["."] * int(max(end_time) / print_scaling)
|
||||
for _cat in range(3):
|
||||
for _chunk in range(2):
|
||||
for _micro in range(self.n_micro):
|
||||
_id = self.get_id(_cat, _chunk, stage, _micro)
|
||||
if end_time[_id] < 0:
|
||||
continue
|
||||
end = int(end_time[_id] / print_scaling)
|
||||
start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
|
||||
for j in range(start, end):
|
||||
if j == start or j == end - 1:
|
||||
stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
|
||||
elif j == start + 1:
|
||||
if _micro >= 10:
|
||||
stage_str[j] = str(_micro // 10)
|
||||
else:
|
||||
stage_str[j] = str(_micro)
|
||||
elif j == start + 2 and _micro >= 10:
|
||||
stage_str[j] = str(_micro % 10)
|
||||
else:
|
||||
stage_str[j] = "-"
|
||||
_str = ""
|
||||
for _c in stage_str:
|
||||
_str += _c
|
||||
print(_str)
|
||||
|
||||
def get_v_schedule(self, only_run_time=False):
|
||||
schedule, end_time, max_bubble = None, None, None
|
||||
expected_time = sum(self.fbw_cost) * self.n_micro * 2
|
||||
for fill_b in [True, False]:
|
||||
for fill_f in [True, False]:
|
||||
_schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f)
|
||||
if max_bubble is None or _max_bubble < max_bubble:
|
||||
max_bubble = _max_bubble
|
||||
schedule = _schedule
|
||||
end_time = _end_time
|
||||
if only_run_time:
|
||||
return max_bubble + expected_time
|
||||
max_bubble / (expected_time + max_bubble)
|
||||
local_order = [[] for _ in range(self.n_stage)]
|
||||
comm_id = {}
|
||||
comm_id_counter = 0
|
||||
post_validation_time = 0
|
||||
for i in range(self.n_stage - 1, -1, -1):
|
||||
pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
|
||||
post_validation_time = max(
|
||||
post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost
|
||||
)
|
||||
# post_validation_time = 0
|
||||
for it in ["RECV_", "SEND_", ""]:
|
||||
if i == 0 and it == "SEND_":
|
||||
continue
|
||||
if i == self.n_stage - 1 and it == "RECV_":
|
||||
continue
|
||||
# stage_ = i - 1 if it == "RECV_" else i
|
||||
stage_ = i
|
||||
local_order[stage_].append(
|
||||
ScheduledNode(
|
||||
type=it + "POST_VALIDATION",
|
||||
chunk=0,
|
||||
stage=stage_,
|
||||
minibatch=0,
|
||||
start_time=post_validation_time,
|
||||
completion_time=post_validation_time,
|
||||
)
|
||||
)
|
||||
comm_id[local_order[stage_][-1]] = comm_id_counter
|
||||
comm_id_counter += 1
|
||||
for i in range(self.n_stage):
|
||||
for _cat_, _chunk_, _micro_ in schedule[i]:
|
||||
complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
|
||||
local_order[i].append(
|
||||
ScheduledNode(
|
||||
type="FBW"[_cat_],
|
||||
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
|
||||
stage=i,
|
||||
minibatch=_micro_,
|
||||
start_time=complete_time - self.fbw_cost[_cat_],
|
||||
completion_time=complete_time,
|
||||
)
|
||||
)
|
||||
if _cat_ == 2: # no communication for W
|
||||
continue
|
||||
cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
|
||||
|
||||
def communicate(send_recv, stage_):
|
||||
# noinspection PyTypeChecker
|
||||
local_order[stage_].append(
|
||||
ScheduledNode(
|
||||
type=send_recv + cat_str,
|
||||
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
|
||||
stage=stage_,
|
||||
minibatch=_micro_,
|
||||
start_time=complete_time,
|
||||
completion_time=complete_time,
|
||||
)
|
||||
)
|
||||
comm_id[local_order[stage_][-1]] = comm_id_counter
|
||||
|
||||
if _chunk_ == 1 and i > 0:
|
||||
communicate("SEND_", i)
|
||||
communicate("RECV_", i - 1)
|
||||
if _chunk_ == 0 and i < self.n_stage - 1:
|
||||
communicate("SEND_", i)
|
||||
communicate("RECV_", i + 1)
|
||||
comm_id_counter += 1
|
||||
for rank in range(self.n_stage):
|
||||
# For nodes with the same timestamp on the same stage, communication will be prioritized.
|
||||
def even_breaker(x: ScheduledNode):
|
||||
# Compute nodes are always delayed.
|
||||
if x.type in ["F", "B", "W"]:
|
||||
return comm_id_counter
|
||||
# For comm nodes, order by their unique comm id
|
||||
return comm_id[x]
|
||||
|
||||
local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x))))
|
||||
# If a recv with intersects with previous computation, reorder them so that recv
|
||||
# is executed before computation and hence can be overlapped.
|
||||
for i in range(len(local_order[rank])):
|
||||
if (
|
||||
i > 0
|
||||
and local_order[rank][i - 1].type in {"F", "B", "W"}
|
||||
and local_order[rank][i].type.startswith("RECV")
|
||||
and "POST_VALIDATION" not in local_order[rank][i].type
|
||||
and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time
|
||||
):
|
||||
local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
|
||||
|
||||
local_order_with_rollback = [[] for _ in range(self.n_stage)]
|
||||
for rank in range(self.n_stage):
|
||||
rollback_comm = set()
|
||||
if rank > 0:
|
||||
for node in local_order[rank - 1]:
|
||||
if node.type == "POST_VALIDATION":
|
||||
break
|
||||
if node.type == "SEND_FORWARD":
|
||||
assert node.chunk == 0
|
||||
rollback_comm.add(node.minibatch)
|
||||
for node in local_order[rank]:
|
||||
if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
|
||||
rollback = True
|
||||
rollback_comm.remove(node.minibatch)
|
||||
else:
|
||||
rollback = False
|
||||
local_order_with_rollback[rank].append(
|
||||
ScheduledNode(
|
||||
type=node.type,
|
||||
chunk=node.chunk,
|
||||
stage=node.stage,
|
||||
minibatch=node.minibatch,
|
||||
start_time=node.start_time,
|
||||
completion_time=node.completion_time,
|
||||
rollback=rollback,
|
||||
)
|
||||
)
|
||||
assert len(rollback_comm) == 0
|
||||
|
||||
return local_order_with_rollback
|
958
colossalai/pipeline/schedule/zero_bubble_pp.py
Normal file
958
colossalai/pipeline/schedule/zero_bubble_pp.py
Normal file
@ -0,0 +1,958 @@
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||
|
||||
from ._utils import (
|
||||
clone,
|
||||
detach,
|
||||
get_batch_size,
|
||||
get_micro_batch,
|
||||
merge_batch,
|
||||
model_forward,
|
||||
release_tensor_data,
|
||||
require_grad,
|
||||
retain_grad,
|
||||
to_device,
|
||||
)
|
||||
from .base import PipelineSchedule
|
||||
|
||||
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
|
||||
|
||||
|
||||
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
||||
if wait_handles is not None:
|
||||
for req in wait_handles:
|
||||
req.wait()
|
||||
|
||||
|
||||
class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
def __init__(
|
||||
self,
|
||||
stage_manager: PipelineStageManager,
|
||||
schedule: List[ScheduledNode],
|
||||
num_model_chunks: int,
|
||||
num_microbatch: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
):
|
||||
super().__init__(stage_manager)
|
||||
# batch info
|
||||
self.num_microbatch = num_microbatch
|
||||
self.microbatch_size = microbatch_size
|
||||
self.num_model_chunks = num_model_chunks
|
||||
self.batch: Any
|
||||
self.batch_size: int
|
||||
self.last_batch_size: Optional[int] = None
|
||||
self.microbatch_offset: List[int]
|
||||
|
||||
self.schedules = schedule
|
||||
# TODO: optim post valid
|
||||
self.do_post_validation = False
|
||||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
|
||||
# check send_tensor_metadata, send_grad_metadata
|
||||
# pp4 as sample, we should follow this meta strategy
|
||||
# send_tensor_meta(fwd) send_grad_meta(bwd)
|
||||
# chunk0 | chunk1 chunk0 | chunk 1
|
||||
# stage 0 T | F F | T
|
||||
# stage 1 T | T T | T
|
||||
# stage 2 T | T T | T
|
||||
# stage 3 F | T F | T
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata = [True, False]
|
||||
self.send_grad_metadata = [False, True]
|
||||
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata = [False, True]
|
||||
self.send_grad_metadata = [True, False]
|
||||
else:
|
||||
self.send_tensor_metadata = [True, True]
|
||||
self.send_grad_metadata = [True, True]
|
||||
|
||||
# meta cache buffer
|
||||
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
|
||||
self.grad_metadata_recv = [None, None]
|
||||
|
||||
# P2P communication
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||
|
||||
# init communication map
|
||||
self.communication_map = {
|
||||
"SEND_FORWARD": self.send_forward,
|
||||
"RECV_FORWARD": self.recv_forward,
|
||||
"SEND_BACKWARD": self.send_backward,
|
||||
"RECV_BACKWARD": self.recv_backward,
|
||||
}
|
||||
|
||||
# init buffer
|
||||
self._free_buffers()
|
||||
|
||||
def _free_buffers(self):
|
||||
# free local buffer
|
||||
# two dim array, first dim is the model chunk, second dim is the microbatch queue
|
||||
|
||||
# x & y buffer for schedule b
|
||||
self.input_tensors = [[], []]
|
||||
self.output_tensors = [[], []]
|
||||
|
||||
# y & dy buffer for schedule w
|
||||
self.output_tensors_dw = [[], []]
|
||||
self.output_tensors_grad_dw = [[], []]
|
||||
|
||||
# buffer for communication
|
||||
self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
|
||||
self.recv_forward_buffer = [
|
||||
[],
|
||||
[],
|
||||
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
|
||||
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
|
||||
self.recv_backward_buffer = [
|
||||
[],
|
||||
[],
|
||||
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
|
||||
|
||||
# y buffer for local send fwd
|
||||
self.local_send_forward_buffer = []
|
||||
# dy buffer for local send bwd
|
||||
self.local_send_backward_buffer = []
|
||||
|
||||
# wait pp buffer
|
||||
self.wait_handles = []
|
||||
|
||||
def assert_buffer_empty(self):
|
||||
# assert buffer is empty at end
|
||||
assert len(self.input_tensors[0]) == 0
|
||||
assert len(self.input_tensors[1]) == 0
|
||||
assert len(self.output_tensors[0]) == 0
|
||||
assert len(self.output_tensors[1]) == 0
|
||||
assert len(self.output_tensors_dw[0]) == 0
|
||||
assert len(self.output_tensors_dw[1]) == 0
|
||||
assert len(self.output_tensors_grad_dw[0]) == 0
|
||||
assert len(self.output_tensors_grad_dw[1]) == 0
|
||||
assert len(self.send_forward_buffer[0]) == 0
|
||||
assert len(self.send_forward_buffer[1]) == 0
|
||||
assert len(self.recv_forward_buffer[0]) == 0
|
||||
assert len(self.recv_forward_buffer[1]) == 0
|
||||
assert len(self.send_backward_buffer[0]) == 0
|
||||
assert len(self.send_backward_buffer[1]) == 0
|
||||
assert len(self.recv_backward_buffer[0]) == 0
|
||||
assert len(self.recv_backward_buffer[1]) == 0
|
||||
assert len(self.local_send_forward_buffer) == 0
|
||||
assert len(self.local_send_backward_buffer) == 0
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
Args:
|
||||
data_iter (Iterable): Data iterator.
|
||||
device (Optional[torch.device], optional): Target device. Defaults to None.
|
||||
"""
|
||||
batch = next(data_iter)
|
||||
if device is not None:
|
||||
batch = tree_map(partial(to_device, device=device), batch)
|
||||
|
||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
|
||||
if self.microbatch_size is None:
|
||||
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatch
|
||||
if self.num_microbatch is None:
|
||||
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
|
||||
self.num_microbatch = self.batch_size // self.microbatch_size
|
||||
|
||||
if not self.forward_only:
|
||||
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
|
||||
assert self.batch_size == self.microbatch_size * self.num_microbatch
|
||||
|
||||
assert (
|
||||
self.num_microbatch % self.stage_manager.num_stages == 0
|
||||
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
|
||||
|
||||
if self.forward_only:
|
||||
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
|
||||
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
def load_micro_batch(self, model_chunk_id: int) -> Any:
|
||||
"""Load a micro batch from the current batch.
|
||||
|
||||
Args:
|
||||
microbatch_id (int): the current model chunk idx.
|
||||
|
||||
Returns:
|
||||
Any: Micro batch.
|
||||
"""
|
||||
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
|
||||
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
|
||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
|
||||
|
||||
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
|
||||
"""Helper method to get the model chunk ID given the iteration number.
|
||||
|
||||
Args:
|
||||
microbatch_id (int): the current microbatch idx
|
||||
forward (bool): if is the forward process
|
||||
|
||||
Returns:
|
||||
int: The model chunk idx of the input microbatch_id
|
||||
"""
|
||||
assert (
|
||||
microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
|
||||
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
|
||||
if not is_forward:
|
||||
# Reverse order
|
||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
prev_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
################
|
||||
# chunk = 0 & is_first_stage
|
||||
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
||||
#################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 0 & not is_first_stage
|
||||
# Recv y from PREV_rank as input
|
||||
#################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(
|
||||
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
else:
|
||||
################
|
||||
# chunk = 1 & is_last_stage
|
||||
# do nothing; cause u get y from local_send_forward_buffer in schedule f
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
# return None, []
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 1 & not is_last_stage
|
||||
# recv y from NEXT_rank as input
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(
|
||||
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
next_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
# bwd chunk0 is right V;
|
||||
################
|
||||
# chunk = 0 & is_last_stage
|
||||
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 0 & not is_last_stage
|
||||
# Recv bwd from next stage;
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
else:
|
||||
# bwd chunk1 is left V;
|
||||
################
|
||||
# chunk = 1 & is_first_stage
|
||||
# do nothing; get loss from local
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 1 & not first stage
|
||||
# recv_backward recv bwd from prev stage;
|
||||
################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
################
|
||||
# chunk = 0 && is_last_stage
|
||||
# do nothing; hold y on local_send_forward_buffer
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 0 && not is_last_stage
|
||||
# self.comm.send_forward send y to NEXT stage
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(
|
||||
output_object=output_tensor,
|
||||
next_rank=next_rank,
|
||||
send_metadata=self.send_tensor_metadata[model_chunk_id],
|
||||
)
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
else:
|
||||
################
|
||||
# chunk = 1 && is_first_stage
|
||||
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 1 && not is_first_stage
|
||||
# self.comm.send_forward send y to PREV stage
|
||||
################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(
|
||||
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
# bwd chunk0 is right V;
|
||||
################
|
||||
# chunk = 0 && is_first_stage
|
||||
# do nothing; cause u are the first chunk in first stage; bwd end
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 0 && not is_first_stage
|
||||
# Send dx to PREV stage;
|
||||
################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
# bwd chunk1 is left V;
|
||||
else:
|
||||
################
|
||||
# chunk = 1 && is_last_stage
|
||||
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 1 && not is_last_stage
|
||||
# Send dx to NEXT stage;
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
micro_batch: Optional[dict],
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
) -> Union[torch.Tensor, dict]:
|
||||
"""Forward one step of the pipeline
|
||||
Args:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
input_obj (Optional[dict]): x;
|
||||
criterion (Callable): loss function;
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||
"""
|
||||
# Load input ids, attention mask and labels
|
||||
# for the first stage, input_obj is None; So,we use micro_batch as input_obj
|
||||
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
|
||||
# Only attention_mask from micro_batch is used
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
# fwd calculate
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||
# last layer in model
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
def backward_b_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
optimizer: OptimizerWrapper,
|
||||
# micro_batch: Optional[dict],
|
||||
input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
) -> Optional[dict]:
|
||||
"""Backward dx step of the pipeline; we calculate "dx = w*dy" here;
|
||||
|
||||
Args:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
optimizer (OptimizerWrapper): Optimizer to update the model
|
||||
input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj)
|
||||
output_obj (Union[dict, torch.Tensor]): y.
|
||||
output_obj_grad (dict): dy.
|
||||
|
||||
Returns:
|
||||
Optional[dict]: dx.
|
||||
"""
|
||||
# calculate bwd b step ; only dx = w*dy;
|
||||
|
||||
# Retain the grad on the input_obj. No need retain_grad microbatch
|
||||
if input_obj is not None:
|
||||
tree_map(retain_grad, input_obj)
|
||||
|
||||
# x, y, dy list for backward_by_grad; Type: list[tensor];
|
||||
input_obj_ = []
|
||||
output_obj_ = []
|
||||
output_obj_grad_ = []
|
||||
|
||||
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
|
||||
|
||||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
assert output_obj_grad is None
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
output_obj_.append(output_obj) # LOSS
|
||||
output_obj_grad_.append(output_obj_grad) # None
|
||||
|
||||
# For other chunk stage, use input_obj as input_obj_;
|
||||
else:
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
# filter item which is not torch.Tensor
|
||||
input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
|
||||
|
||||
try:
|
||||
ctx = optimizer.no_sync()
|
||||
except AttributeError:
|
||||
ctx = model_chunk.no_sync()
|
||||
with ctx:
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj_,
|
||||
grad=output_obj_grad_,
|
||||
# inputs=input_obj_,
|
||||
retain_graph=False,
|
||||
)
|
||||
# Format output_obj_grad
|
||||
input_obj_grad = dict()
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
pass
|
||||
else:
|
||||
for k, v in input_obj.items():
|
||||
if isinstance(v, torch.Tensor) and v.grad is not None:
|
||||
input_obj_grad[k] = v.grad
|
||||
return input_obj_grad
|
||||
|
||||
def backward_w_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
optimizer: OptimizerWrapper,
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
):
|
||||
"""Backward dw step of the pipeline; we calculate "dw = x*dy" here;
|
||||
|
||||
Args:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
optimizer (OptimizerWrapper): Optimizer to update the model
|
||||
output_obj (Union[dict, torch.Tensor]): y.
|
||||
output_obj_grad (dict): dy.
|
||||
|
||||
Returns:
|
||||
Nothing need to return; we only calculate dw then update w;
|
||||
"""
|
||||
# calculate bwd w step ; only dw = x*dy;
|
||||
|
||||
# y, dy list for w backward_by_grad; Type: list[tensor];
|
||||
output_obj_ = []
|
||||
output_obj_grad_ = []
|
||||
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# loss backward; output_obj is loss;
|
||||
output_obj_.append(output_obj) # LOSS
|
||||
output_obj_grad_.append(None) # None
|
||||
else:
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
# filter item which is not torch.Tensor
|
||||
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
|
||||
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj_,
|
||||
grad=output_obj_grad_,
|
||||
inputs=list(model_chunk.parameters()),
|
||||
retain_graph=False,
|
||||
)
|
||||
|
||||
def schedule_f(
|
||||
self,
|
||||
scheduled_node,
|
||||
model_chunk: torch.nn.ModuleList,
|
||||
model_chunk_id: int,
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
):
|
||||
"""A complete forward schedule; Include recv fwd --> cal fwd --> send fwd;
|
||||
|
||||
Args:
|
||||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
criterion (Callable): loss function;
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
|
||||
|
||||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
# Step1: recv fwd
|
||||
if model_chunk_id == 0:
|
||||
# is first stage; get input from microbatch
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = None # (tensor, wait_handle)
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
for h in input_obj[1]:
|
||||
h.wait()
|
||||
input_obj = input_obj[0]
|
||||
else:
|
||||
# is last stage; recv from local
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
input_obj = self.local_send_forward_buffer.pop(0)
|
||||
# not last stage; recv from next
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
for h in input_obj[1]:
|
||||
h.wait()
|
||||
input_obj = input_obj[0]
|
||||
# Here, let input_obj.requires_grad_()
|
||||
# if input_obj is not None:
|
||||
if not isinstance(input_obj, torch.Tensor):
|
||||
tree_map(require_grad, input_obj)
|
||||
|
||||
# Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
|
||||
# tree_map(torch.Tensor.requires_grad_, micro_batch)
|
||||
|
||||
# Step2: fwd step
|
||||
output_obj = self.forward_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
micro_batch=micro_batch,
|
||||
input_obj=input_obj,
|
||||
criterion=criterion,
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
# Step3:
|
||||
# 3-1:detach output; detach output for send fwd;
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# We should not detach bwd LOSS
|
||||
pass
|
||||
else:
|
||||
# detach output
|
||||
detached_output_obj = tree_map(detach, output_obj)
|
||||
# 3-2 clone detached_output_obj
|
||||
detached_output_obj = tree_map(clone, detached_output_obj)
|
||||
|
||||
# 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# We should not release_tensor_data bwd LOSS
|
||||
pass
|
||||
else:
|
||||
# release_tensor_data output
|
||||
tree_map(release_tensor_data, output_obj)
|
||||
|
||||
# add input and output object for backward b
|
||||
self.input_tensors[model_chunk_id].append(input_obj)
|
||||
|
||||
# for bwd b&w, we only need the graph(grad_fn) of output_obj
|
||||
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.output_tensors[model_chunk_id].append(output_obj)
|
||||
else:
|
||||
self.output_tensors[model_chunk_id].append(output_obj)
|
||||
|
||||
# add output to send_fwd_buffer
|
||||
if model_chunk_id == 0: # chunk 0
|
||||
# is last stage; send to local_send_forward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_forward_buffer.append(detached_output_obj)
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||
else: # chunk 1
|
||||
# is first stage; end of fwd; do nothing
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
pass
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||
|
||||
def schedule_b(
|
||||
self,
|
||||
scheduled_node,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
optimizer: OptimizerWrapper,
|
||||
):
|
||||
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
|
||||
|
||||
Args:
|
||||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
# Step1: recv bwd
|
||||
if model_chunk_id == 0:
|
||||
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
# chunk0 not last stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
for h in output_tensor_grad[1]:
|
||||
h.wait()
|
||||
output_tensor_grad = output_tensor_grad[0]
|
||||
else:
|
||||
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
output_tensor_grad = None
|
||||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
for h in output_tensor_grad[1]:
|
||||
h.wait()
|
||||
output_tensor_grad = output_tensor_grad[0]
|
||||
|
||||
# get input and output object from buffer;
|
||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||
|
||||
input_object_grad = self.backward_b_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
optimizer=optimizer,
|
||||
input_obj=input_obj,
|
||||
output_obj=output_obj,
|
||||
output_obj_grad=output_tensor_grad,
|
||||
)
|
||||
|
||||
# Step3: send bwd
|
||||
if model_chunk_id == 0:
|
||||
# do nothing; end of bwd;
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
pass
|
||||
# save input_object_grad to send_backward_buffer
|
||||
else:
|
||||
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
||||
else:
|
||||
# send to local_send_backward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_backward_buffer.append(input_object_grad)
|
||||
# send to next
|
||||
else:
|
||||
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
||||
WeightGradStore.flush(chunk=model_chunk_id)
|
||||
|
||||
def schedule_w(
|
||||
self,
|
||||
scheduled_node,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
optimizer: OptimizerWrapper,
|
||||
):
|
||||
"""A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
|
||||
|
||||
Args:
|
||||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
WeightGradStore.pop(chunk=model_chunk_id)
|
||||
|
||||
def run_forward_only(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> Dict:
|
||||
assert self.forward_only
|
||||
|
||||
# prepare batch
|
||||
self.load_batch(data_iter)
|
||||
|
||||
# prepare accum loss & output
|
||||
accum_loss = None
|
||||
|
||||
# reset accum loss at fwd end;
|
||||
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
|
||||
|
||||
# while we still have schedules_node in self.schedules
|
||||
for it in range(len(self.schedules)):
|
||||
scheduled_node = self.schedules[it]
|
||||
|
||||
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
|
||||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
communication_func(scheduled_node.chunk)
|
||||
if scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
criterion=criterion,
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
# return loss & output
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def run_forward_backward(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
Runs Zerobubble schedule, with communication between pipeline stages.
|
||||
"""
|
||||
# prepare batch
|
||||
self.load_batch(data_iter)
|
||||
|
||||
# prepare accum loss & output
|
||||
accum_loss = None
|
||||
|
||||
# reset accum loss at fwd end;
|
||||
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
|
||||
|
||||
# while we still have schedules_node in self.schedules
|
||||
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
||||
for it in range(len(schedule)):
|
||||
scheduled_node = schedule[it]
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
wait_handle = communication_func(scheduled_node.chunk)
|
||||
# We wait recv handle in fwd step and bwd step. Here only need to wait for send handle
|
||||
if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
|
||||
self.wait_handles.append(wait_handle)
|
||||
elif scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
criterion=criterion,
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
elif scheduled_node.type == "B":
|
||||
self.schedule_b(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
elif scheduled_node.type == "W":
|
||||
self.schedule_w(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
# wait here to ensure all communication is done
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
# return loss & output
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
|
||||
data_iter (Iterable): Data iterator.
|
||||
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
|
||||
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
|
||||
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
|
||||
|
||||
Returns:
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
"""
|
||||
self.forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||
|
||||
if self.forward_only:
|
||||
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
|
||||
else:
|
||||
result = self.run_forward_backward(
|
||||
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
||||
self.assert_buffer_empty()
|
||||
return result
|
@ -26,6 +26,7 @@ class PipelineStageManager:
|
||||
pg_mesh: ProcessGroupMesh,
|
||||
pipeline_axis: int,
|
||||
enable_interleave: bool = False,
|
||||
use_zbv: bool = False,
|
||||
num_model_chunks: int = 1,
|
||||
num_layers_per_stage: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
@ -49,6 +50,7 @@ class PipelineStageManager:
|
||||
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
|
||||
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
|
||||
self.is_interleave = enable_interleave
|
||||
self.use_zbv = use_zbv
|
||||
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
|
||||
self.num_model_chunks: int = num_model_chunks
|
||||
# for shardformer, hold stage indices of model
|
||||
@ -85,6 +87,16 @@ class PipelineStageManager:
|
||||
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
||||
|
||||
stage_indices = []
|
||||
if self.use_zbv:
|
||||
stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
|
||||
stage_indices.append(
|
||||
[
|
||||
num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
|
||||
num_layers_per_stage_accumulated[2 * num_stages - stage],
|
||||
]
|
||||
)
|
||||
return stage_indices
|
||||
|
||||
for model_chunk in range(num_model_chunks):
|
||||
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
|
||||
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
|
||||
@ -124,7 +136,11 @@ class PipelineStageManager:
|
||||
if not self.is_interleave or ignore_chunk:
|
||||
return self.stage == self.num_stages - 1
|
||||
else:
|
||||
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
|
||||
# use zero bubble pipeline
|
||||
if self.use_zbv:
|
||||
return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
|
||||
else:
|
||||
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
|
||||
|
||||
@property
|
||||
def num_stages(self) -> int:
|
||||
@ -207,7 +223,6 @@ class PipelineStageManager:
|
||||
|
||||
# calculate the num_layers per stage
|
||||
layers_per_stage = [quotient] * num_stages * num_model_chunks
|
||||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
|
||||
|
32
colossalai/pipeline/weight_grad_store.py
Normal file
32
colossalai/pipeline/weight_grad_store.py
Normal file
@ -0,0 +1,32 @@
|
||||
import queue
|
||||
|
||||
|
||||
class WeightGradStore:
|
||||
|
||||
cache = []
|
||||
weight_grad_queue = [queue.Queue(), queue.Queue()]
|
||||
|
||||
@classmethod
|
||||
def put(cls, total_input, grad_output, weight, func):
|
||||
# func(total_input, grad_output, weight.main_grad)
|
||||
cls.cache.append((total_input, grad_output, weight, func))
|
||||
|
||||
@classmethod
|
||||
def flush(cls, chunk=0):
|
||||
cls.weight_grad_queue[chunk].put(cls.cache)
|
||||
cls.cache = []
|
||||
|
||||
@classmethod
|
||||
def pop(cls, chunk=0):
|
||||
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
|
||||
if cls.weight_grad_queue[chunk].qsize() > 0:
|
||||
stored_grads = cls.weight_grad_queue[chunk].get()
|
||||
for total_input, grad_output, weight, func in stored_grads:
|
||||
if weight.grad is not None:
|
||||
func(total_input, grad_output, weight.grad)
|
||||
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
|
||||
else:
|
||||
grad_weight = func(total_input, grad_output)
|
||||
weight.grad = grad_weight
|
||||
else:
|
||||
raise Exception("Pop empty queue.")
|
@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
|
||||
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
|
||||
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
|
||||
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLin
|
||||
__all__ = [
|
||||
"Embedding1D",
|
||||
"VocabParallelEmbedding1D",
|
||||
"LinearWithGradAccum",
|
||||
"Linear1D_Col",
|
||||
"Linear1D_Row",
|
||||
"GPT2FusedLinearConv1D_Col",
|
||||
|
@ -1,7 +1,11 @@
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||
|
||||
from .utils import is_share_sp_tp
|
||||
|
||||
try:
|
||||
@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.fp8_communication = fp8_communication
|
||||
ctx.use_zbv = use_zbv
|
||||
if bias is not None:
|
||||
output = F.linear(input_, weight, bias)
|
||||
else:
|
||||
@ -143,6 +148,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
fp8_communication = ctx.fp8_communication
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
|
||||
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
|
||||
|
||||
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||
return wgrad_gemm_func(_grad_output_.t(), _input_)
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||
if use_bias:
|
||||
@ -164,24 +176,160 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
if use_zbv:
|
||||
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
|
||||
if grad.dtype == torch.float32:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
elif grad.dtype in (torch.float16, torch.bfloat16):
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass,
|
||||
wgrad_gemm_func=torch.matmul,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce and not fp8_communication:
|
||||
handle.wait()
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
class LinearWithGradAccum(torch.autograd.Function):
|
||||
"""
|
||||
Linear layer baseline (no tensor parallel version).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.use_zbv = use_zbv
|
||||
if bias is not None:
|
||||
output = F.linear(input_, weight, bias)
|
||||
else:
|
||||
output = F.linear(input_, weight)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
|
||||
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
|
||||
|
||||
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||
return wgrad_gemm_func(_grad_output_.t(), _input_)
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||
if use_bias:
|
||||
bias.view(bias.shape)
|
||||
|
||||
total_input = input.contiguous()
|
||||
grad_input = grad_output.matmul(weight)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
if len(grad_output.shape) > 2:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.view(-1, total_input.shape[-1])
|
||||
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if use_zbv:
|
||||
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
|
||||
if grad.dtype == torch.float32:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
elif grad.dtype in (torch.float16, torch.bfloat16):
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass,
|
||||
wgrad_gemm_func=torch.matmul,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
@ -966,12 +1114,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
|
||||
)
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
def linear_with_async_comm(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
|
||||
):
|
||||
return LinearWithAsyncCommunication.apply(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
|
||||
)
|
||||
|
||||
|
||||
def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
|
||||
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
|
||||
|
||||
|
||||
def linear_gather_forward_reducescatter_backward(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
|
||||
):
|
||||
|
@ -27,13 +27,155 @@ from ._operation import (
|
||||
linear_gather_forward_reducescatter_backward,
|
||||
linear_reducescatter_forward_gather_backward,
|
||||
linear_with_async_comm,
|
||||
linear_with_grad_accum,
|
||||
reduce_forward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||
from .utils import create_randomizer_with_offset, is_share_sp_tp
|
||||
|
||||
__all__ = ["Linear1D_Col", "Linear1D_Row"]
|
||||
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
|
||||
|
||||
|
||||
class LinearWithGradAccum(ParallelModule):
|
||||
r"""Linear layer with no parallelism.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
device (`torch.device`): The device of parameters, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is :math:`Y_i = XA_i`, defaults to False
|
||||
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
|
||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (`typing.Callable`):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (`typing.Callable`):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
use_zbv: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
if weight is None:
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
|
||||
linear_1d = LinearWithGradAccum(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = input_
|
||||
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
output_parallel = linear_with_grad_accum(
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
False,
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
|
||||
output = output_parallel
|
||||
|
||||
if self.skip_bias_add:
|
||||
return output, self.bias
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
class Linear1D_Col(ParallelModule):
|
||||
@ -81,6 +223,7 @@ class Linear1D_Col(ParallelModule):
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
use_zbv: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||
@ -95,6 +238,7 @@ class Linear1D_Col(ParallelModule):
|
||||
self.device = device
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
@ -209,9 +353,14 @@ class Linear1D_Col(ParallelModule):
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
True,
|
||||
fp8_communication=self.fp8_communication,
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(
|
||||
@ -267,6 +416,7 @@ class Linear1D_Row(ParallelModule):
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
fp8_communication: bool = False,
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -282,6 +432,7 @@ class Linear1D_Row(ParallelModule):
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -82,7 +82,7 @@ class LlamaPipelineForwards:
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
if inputs_embeds is None:
|
||||
@ -191,7 +191,6 @@ class LlamaPipelineForwards:
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||
moe_dp_group: ProcessGroup,
|
||||
ep_group: ProcessGroup,
|
||||
fp8_communication: bool = False,
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
assert tp_group is not None
|
||||
assert moe_dp_group is not None
|
||||
@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||
self.ep_rank = dist.get_rank(ep_group)
|
||||
self.ep_group = ep_group
|
||||
self.fp8_communication = fp8_communication
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
if self.num_experts % self.ep_size != 0:
|
||||
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
||||
@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||
if self.tp_group.size() > 1:
|
||||
for expert in held_experts:
|
||||
expert.w1 = Linear1D_Col.from_native_module(
|
||||
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
|
||||
expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||
)
|
||||
expert.w3 = Linear1D_Col.from_native_module(
|
||||
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
|
||||
expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||
)
|
||||
expert.w2 = Linear1D_Row.from_native_module(
|
||||
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
|
||||
expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||
)
|
||||
|
||||
for p in self.experts.parameters():
|
||||
@ -379,7 +381,6 @@ class MixtralPipelineForwards:
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
@ -399,6 +400,7 @@ class MixtralPipelineForwards:
|
||||
|
||||
if output_router_logits and past_router_logits is not None:
|
||||
all_router_logits = past_router_logits + all_router_logits
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
@ -512,7 +514,6 @@ class MixtralPipelineForwards:
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
|
@ -75,6 +75,8 @@ class BertPolicy(Policy):
|
||||
|
||||
sp_partial_derived = sp_mode == "split_gather"
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
@ -97,6 +99,7 @@ class BertPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -105,6 +108,7 @@ class BertPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -113,6 +117,7 @@ class BertPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -125,6 +130,7 @@ class BertPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -138,6 +144,7 @@ class BertPolicy(Policy):
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -146,6 +153,97 @@ class BertPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[BertEmbeddings] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_bert_intermediate_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BertIntermediate,
|
||||
)
|
||||
|
||||
elif use_zbv:
|
||||
policy[BertLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.query",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.key",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
RMSNorm,
|
||||
@ -60,6 +61,8 @@ class LlamaPolicy(Policy):
|
||||
else:
|
||||
norm_cls = RMSNorm
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
sp_mode = self.shard_config.sequence_parallelism_mode or None
|
||||
sp_size = self.shard_config.sequence_parallel_size or None
|
||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||
@ -102,7 +105,7 @@ class LlamaPolicy(Policy):
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
)
|
||||
|
||||
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
num_q_heads % tp_size == 0
|
||||
@ -126,37 +129,135 @@ class LlamaPolicy(Policy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# not enable tp, replace layer to LinearWithGradAccum
|
||||
elif use_zbv:
|
||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -261,9 +362,10 @@ class LlamaPolicy(Policy):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
@ -353,11 +455,15 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
|
||||
return []
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
@ -379,7 +485,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||
from transformers import LlamaForSequenceClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
@ -391,12 +499,32 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
# enable tp, replace layer to LinearWithGradAccum
|
||||
elif use_zbv:
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
LlamaForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
# to be confirmed
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
@ -411,7 +539,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
|
@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
VocabParallelEmbedding1D,
|
||||
@ -62,6 +63,8 @@ class MistralPolicy(Policy):
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn(
|
||||
@ -90,6 +93,7 @@ class MistralPolicy(Policy):
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -97,6 +101,7 @@ class MistralPolicy(Policy):
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -104,6 +109,7 @@ class MistralPolicy(Policy):
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -111,6 +117,7 @@ class MistralPolicy(Policy):
|
||||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -118,6 +125,7 @@ class MistralPolicy(Policy):
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -125,6 +133,7 @@ class MistralPolicy(Policy):
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@ -132,6 +141,68 @@ class MistralPolicy(Policy):
|
||||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[MistralDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -7,9 +7,18 @@ from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer.linear import Linear1D_Row
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
VocabParallelEmbedding1D,
|
||||
)
|
||||
|
||||
# from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
# from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
||||
# from colossalai.shardformer.layer.linear import Linear1D_Row
|
||||
from colossalai.shardformer.modeling.mixtral import (
|
||||
EPMixtralSparseMoeBlock,
|
||||
MixtralPipelineForwards,
|
||||
@ -52,6 +61,7 @@ class MixtralPolicy(Policy):
|
||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||
tp_size = self.shard_config.tensor_parallel_size
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# modified for both SP and TP
|
||||
num_q_heads = self.model.config.num_attention_heads
|
||||
@ -124,31 +134,92 @@ class MixtralPolicy(Policy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe.gate",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={
|
||||
"gather_output": True,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
elif use_zbv:
|
||||
policy[MixtralDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe.gate",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
@ -179,6 +250,7 @@ class MixtralPolicy(Policy):
|
||||
"tp_group": self.shard_config.tensor_parallel_process_group,
|
||||
"moe_dp_group": self.shard_config.moe_dp_group,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
)
|
||||
],
|
||||
@ -258,14 +330,30 @@ class MixtralPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
stage_manager.stage_indices = stage_indices
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
# for zbv, when is_first_stage (last fwd), we append norm
|
||||
# for interleaved, when is_last_stage (last fwd), we also append norm
|
||||
held_layers.append(module.norm)
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
return held_layers
|
||||
|
||||
|
||||
@ -297,6 +385,7 @@ class MixtralModelPolicy(MixtralPolicy):
|
||||
class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
# TODO: assign pg mesh from plugin to all modules
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for causal lm
|
||||
@ -306,9 +395,29 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
],
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
elif use_zbv:
|
||||
new_item = {
|
||||
MixtralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
@ -327,7 +436,9 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
@ -353,6 +464,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
|
||||
from transformers import MixtralForSequenceClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for sequence classification
|
||||
@ -362,7 +474,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, OrderedDict, Tuple
|
||||
from typing import Any, List, OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -78,9 +78,7 @@ def check_state_dict_equal(
|
||||
v1 = v1.to(v2.dtype)
|
||||
assert_close_loose(v1, v2)
|
||||
else:
|
||||
if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
|
||||
v2 = tuple(v2)
|
||||
assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
|
||||
assert v1 == v2, f"{v1} not equals to {v2}"
|
||||
|
||||
|
||||
def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
|
||||
|
@ -1,6 +1,5 @@
|
||||
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@ -12,6 +11,26 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
|
||||
_TYPES_INV = {v: k for k, v in _TYPES.items()}
|
||||
import io
|
||||
|
||||
from torch.distributed.distributed_c10d import _pickler, _unpickler
|
||||
|
||||
|
||||
def _object_to_tensor(obj, device):
|
||||
f = io.BytesIO()
|
||||
_pickler(f).dump(obj)
|
||||
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
|
||||
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
|
||||
# Otherwise, it will casue 100X slowdown.
|
||||
# See: https://github.com/pytorch/pytorch/issues/65696
|
||||
byte_tensor = torch.ByteTensor(byte_storage).to(device)
|
||||
return byte_tensor
|
||||
|
||||
|
||||
def _tensor_to_object(tensor, tensor_size):
|
||||
tensor = tensor.cpu()
|
||||
buf = tensor.numpy().tobytes()[:tensor_size]
|
||||
return _unpickler(io.BytesIO(buf)).load()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -28,49 +47,68 @@ class PreparedData:
|
||||
offset: int
|
||||
|
||||
|
||||
def flatten_dict(nested_dict, parent_key="", separator="^"):
|
||||
"""
|
||||
Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
|
||||
|
||||
nested_dict: The input nested dictionary.
|
||||
parent_key: The parent key currently being processed.
|
||||
separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary."
|
||||
"""
|
||||
items = []
|
||||
for k, v in nested_dict.items():
|
||||
new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
|
||||
if isinstance(v, dict):
|
||||
items.extend(flatten_dict(v, new_key, separator).items())
|
||||
else:
|
||||
v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
|
||||
items.append((new_key, v))
|
||||
|
||||
return dict(items)
|
||||
def _cast_to_tensor(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj
|
||||
return _object_to_tensor(obj, "cpu")
|
||||
|
||||
|
||||
def unflatten_dict(flattened_dict, separator="^"):
|
||||
"""
|
||||
Restore a flattened dictionary back to a multi-level nested dictionary.
|
||||
def _cast_to_object(tensor: torch.Tensor):
|
||||
return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
|
||||
|
||||
flattened_dict: The flattened dictionary.
|
||||
separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
|
||||
"""
|
||||
nested_dict = {}
|
||||
for key, value in flattened_dict.items():
|
||||
keys = key.split(separator)
|
||||
try:
|
||||
keys[0] = int(keys[0])
|
||||
except ValueError:
|
||||
warnings.warn(f"{key[0]} can't convert to integer")
|
||||
d = nested_dict
|
||||
for part in keys[:-1]:
|
||||
if part not in d:
|
||||
d[part] = {}
|
||||
d = d[part]
|
||||
assert isinstance(value, torch.Tensor)
|
||||
d[keys[-1]] = value
|
||||
|
||||
return nested_dict
|
||||
def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
|
||||
flat_dict = {}
|
||||
non_tensor_keys = []
|
||||
if "state" in state_dict:
|
||||
# 3-level dict
|
||||
states = state_dict["state"]
|
||||
else:
|
||||
# 2-level dict, usually for optimizer state dict shard
|
||||
states = state_dict
|
||||
|
||||
for idx, d in states.items():
|
||||
for k, v in d.items():
|
||||
nested_key = f"state{seperator}{idx}{seperator}{k}"
|
||||
if not isinstance(v, torch.Tensor):
|
||||
non_tensor_keys.append(nested_key)
|
||||
flat_dict[nested_key] = _cast_to_tensor(v)
|
||||
if "param_groups" in state_dict:
|
||||
flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
|
||||
non_tensor_keys.append("param_groups")
|
||||
if len(non_tensor_keys) > 0:
|
||||
metadata = {"non_tensor_keys": non_tensor_keys}
|
||||
else:
|
||||
metadata = None
|
||||
return flat_dict, metadata
|
||||
|
||||
|
||||
def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
|
||||
state_dict = {}
|
||||
if metadata is not None:
|
||||
non_tensor_keys = json.loads(metadata["non_tensor_keys"])
|
||||
else:
|
||||
non_tensor_keys = []
|
||||
flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
|
||||
if "param_groups" in flat_dict:
|
||||
# 3-level dict
|
||||
state_dict["param_groups"] = flat_dict.pop("param_groups")
|
||||
state_dict["state"] = {}
|
||||
states = state_dict["state"]
|
||||
else:
|
||||
# 2-level dict, usually for optimizer state dict shard
|
||||
states = state_dict
|
||||
|
||||
for k, v in flat_dict.items():
|
||||
parts = k.split(seperator)
|
||||
assert len(parts) == 3 and parts[0] == "state"
|
||||
idx = int(parts[1])
|
||||
key = parts[2]
|
||||
if idx not in states:
|
||||
states[idx] = {}
|
||||
states[idx][key] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def prepare(
|
||||
@ -124,10 +162,8 @@ def save(
|
||||
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
|
||||
|
||||
|
||||
def save_nested(
|
||||
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
|
||||
) -> None:
|
||||
flatten_data = flatten_dict(state_dict)
|
||||
def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
flatten_data, metadata = _flatten_optim_state_dict(state_dict)
|
||||
save(f_writer, flatten_data, metadata)
|
||||
|
||||
|
||||
@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
|
||||
with safe_open(checkpoint_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
state_dict_load = load_file(checkpoint_path)
|
||||
state_dict = unflatten_dict(state_dict_load)
|
||||
if metadata is None:
|
||||
return state_dict
|
||||
metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
|
||||
combined_state_dict = {"state": state_dict}
|
||||
combined_state_dict.update(metadata)
|
||||
return combined_state_dict
|
||||
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
|
||||
return state_dict
|
||||
|
@ -351,7 +351,7 @@ class GeminiDDP(ModelWrapper):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
|
||||
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
|
||||
|
||||
@staticmethod
|
||||
|
@ -300,12 +300,14 @@ class GeminiOptimizer(OptimizerWrapper):
|
||||
loss = self.mix_precision_mixin.pre_backward(loss)
|
||||
self.module.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
|
||||
def backward_by_grad(
|
||||
self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
|
||||
):
|
||||
# This function is called except the last stage of pipeline parallel
|
||||
# It receives the scaled grad from the previous rank
|
||||
# No need to scale the grad again
|
||||
# Need to unscale when optimizing
|
||||
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
|
||||
grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
|
||||
self.module.backward_by_grad(tensor, grad)
|
||||
|
||||
def _maybe_move_fp32_params(self):
|
||||
|
@ -448,7 +448,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
# torch.optim.Optimizer methods
|
||||
################################
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
def backward(self, loss, inputs=None, retain_graph=False):
|
||||
assert not (
|
||||
self._partition_grads and not self.require_grad_sync
|
||||
), "ZeRO2(partition_grads) and no_sync are not compatible"
|
||||
@ -458,7 +458,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
ctx = nullcontext() if self._backward_context is None else self._backward_context()
|
||||
with ctx:
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
loss.backward(inputs=inputs, retain_graph=retain_graph)
|
||||
|
||||
if not self.require_grad_sync:
|
||||
return
|
||||
@ -469,14 +469,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
if self._overlap_communication:
|
||||
get_accelerator().synchronize()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
|
||||
assert not (
|
||||
self._partition_grads and not self.require_grad_sync
|
||||
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
|
||||
|
||||
if self.mixed_precision_mixin is not None:
|
||||
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
|
||||
torch.autograd.backward(tensor, grad)
|
||||
torch.autograd.backward(
|
||||
tensor,
|
||||
grad,
|
||||
inputs=inputs,
|
||||
retain_graph=retain_graph,
|
||||
)
|
||||
|
||||
if not self.require_grad_sync:
|
||||
return
|
||||
|
@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
|
||||
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
@ -39,6 +40,7 @@ MODEL_CONFIGS = {
|
||||
),
|
||||
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
|
||||
"7b": LlamaConfig(max_position_embeddings=4096),
|
||||
# "7b": LlamaConfig(num_hidden_layers=4, max_position_embeddings=4096),
|
||||
"13b": LlamaConfig(
|
||||
hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
@ -91,7 +93,7 @@ def main():
|
||||
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
|
||||
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
||||
|
||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
|
||||
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||
parser.add_argument(
|
||||
@ -106,6 +108,7 @@ def main():
|
||||
parser.add_argument("--no_cache", action="store_true")
|
||||
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
|
||||
parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
|
||||
parser.add_argument("--overlap_allgather", action="store_true")
|
||||
parser.add_argument(
|
||||
"--sp_mode",
|
||||
@ -137,6 +140,11 @@ def main():
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.config in MODEL_CONFIGS:
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||
|
||||
use_empty_init = True
|
||||
if args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
@ -210,6 +218,24 @@ def main():
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
if args.pp_style == "zbv":
|
||||
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
|
||||
mem_w = -32 * config.hidden_size
|
||||
mem_b = -mem_w - mem_f
|
||||
scheduler_nodes = PipelineGraph(
|
||||
n_stage=args.pp,
|
||||
n_micro=args.batch_size // args.mbs,
|
||||
f_cost=1000,
|
||||
b_cost=1000,
|
||||
w_cost=1000,
|
||||
c_cost=1,
|
||||
f_mem=mem_f * 1.5,
|
||||
b_mem=mem_b * 1.5,
|
||||
w_mem=mem_w * 1.5,
|
||||
).get_v_schedule()
|
||||
else:
|
||||
scheduler_nodes = None
|
||||
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
@ -227,6 +253,7 @@ def main():
|
||||
overlap_allgather=args.overlap_allgather,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
scheduler_nodes=scheduler_nodes,
|
||||
**hybrid_kwargs,
|
||||
)
|
||||
elif args.plugin == "3d_cpu":
|
||||
@ -242,7 +269,7 @@ def main():
|
||||
microbatch_size=args.mbs,
|
||||
initial_scale=2**8,
|
||||
precision="bf16",
|
||||
overlap_p2p=args.overlap,
|
||||
overlap_p2p=args.overlap_p2p,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
@ -260,6 +287,7 @@ def main():
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||
|
||||
torch.cuda.manual_seed(42)
|
||||
dataset = RandomDataset(
|
||||
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
||||
@ -319,7 +347,7 @@ def main():
|
||||
args.profile,
|
||||
args.ignore_steps,
|
||||
1, # avoid creating massive log files
|
||||
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||
save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||
nsys=args.nsys,
|
||||
) as prof:
|
||||
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
||||
@ -334,8 +362,12 @@ def main():
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
print(f"Step {step} loss: {loss}")
|
||||
if args.pp_style == "zbv":
|
||||
if coordinator.is_master():
|
||||
print(f"Step {step} loss: {loss}")
|
||||
else:
|
||||
if coordinator.is_last_process():
|
||||
print(f"Step {step} loss: {loss}")
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
@ -11,6 +11,7 @@ from data_utils import RandomDataset
|
||||
from model_utils import format_numel_str, get_model_numel
|
||||
from performance_evaluator import PerformanceEvaluator, get_profile_context
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
|
||||
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
@ -85,7 +87,7 @@ def main():
|
||||
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
|
||||
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
||||
|
||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
|
||||
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||
parser.add_argument(
|
||||
@ -129,7 +131,29 @@ def main():
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.config in MODEL_CONFIGS:
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||
|
||||
if args.plugin == "3d":
|
||||
if args.pp_style == "zbv":
|
||||
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
|
||||
mem_w = -32 * config.hidden_size
|
||||
mem_b = -mem_w - mem_f
|
||||
scheduler_nodes = PipelineGraph(
|
||||
n_stage=args.pp,
|
||||
n_micro=args.batch_size // args.mbs,
|
||||
f_cost=1000,
|
||||
b_cost=1000,
|
||||
w_cost=1000,
|
||||
c_cost=1,
|
||||
f_mem=mem_f,
|
||||
b_mem=mem_b,
|
||||
w_mem=mem_w,
|
||||
).get_v_schedule()
|
||||
else:
|
||||
scheduler_nodes = None
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
ep_size=args.ep,
|
||||
tp_size=args.tp,
|
||||
@ -143,11 +167,13 @@ def main():
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
enable_flash_attention=args.xformers,
|
||||
microbatch_size=args.mbs,
|
||||
num_microbatches=args.batch_size // args.mbs,
|
||||
precision="bf16",
|
||||
enable_metadata_cache=not args.no_cache,
|
||||
overlap_allgather=args.overlap_allgather,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
scheduler_nodes=scheduler_nodes,
|
||||
**hybrid_kwargs,
|
||||
)
|
||||
else:
|
||||
@ -183,8 +209,10 @@ def main():
|
||||
with init_ctx:
|
||||
model = MixtralForCausalLM(config=config).to(torch.bfloat16)
|
||||
|
||||
# if args.grad_checkpoint:
|
||||
# model.gradient_checkpointing_enable()
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
@ -229,8 +257,12 @@ def main():
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
print(f"Step {step} loss: {loss}")
|
||||
if args.pp_style == "zbv":
|
||||
if dist.get_rank() == 0:
|
||||
print(f"Step {step} loss: {loss}")
|
||||
else:
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
print(f"Step {step} loss: {loss}")
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
@ -21,11 +21,16 @@ def divide(x: float, y: float) -> float:
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
# BUG: RuntimeError: Invalid scalar type when use dist.all_reduce(tensor, group=gloo_group)
|
||||
# # Use CPU tensor to avoid OOM/weird NCCl error
|
||||
# gloo_group = dist.new_group(backend="gloo")
|
||||
# tensor = torch.tensor([x], device="cpu")
|
||||
# dist.all_reduce(tensor, group=gloo_group)
|
||||
# tensor = tensor / world_size
|
||||
# return tensor.item()
|
||||
|
||||
# Use CPU tensor to avoid OOM/weird NCCl error
|
||||
gloo_group = dist.new_group(backend="gloo")
|
||||
tensor = torch.tensor([x], device="cpu")
|
||||
dist.all_reduce(tensor, group=gloo_group)
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from colossalai.utils.safetensors import load_flat, save_nested
|
||||
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
|
||||
|
||||
try:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
@ -11,17 +11,29 @@ except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
|
||||
|
||||
from colossalai.testing import check_state_dict_equal
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def test_save_load():
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
optimizer_state_dict = {
|
||||
0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
|
||||
1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
|
||||
2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
|
||||
}
|
||||
# group_dict = {"param_groups": [0, 1, 2]}
|
||||
group_dict = {
|
||||
"state": {
|
||||
0: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
1: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
2: {
|
||||
"step": torch.tensor(1.0),
|
||||
"exp_avg": torch.rand((1024, 1024)),
|
||||
"exp_avg_sq": torch.rand((1024, 1024)),
|
||||
},
|
||||
},
|
||||
"param_groups": [
|
||||
{
|
||||
"lr": 0.001,
|
||||
@ -94,22 +106,26 @@ def test_save_load():
|
||||
61,
|
||||
],
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
metadata = deepcopy(group_dict)
|
||||
|
||||
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
|
||||
save_nested(f_writer, optimizer_state_dict, metadata)
|
||||
save_nested(f_writer, optimizer_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
|
||||
load_state_dict = load_flat(optimizer_saved_path)
|
||||
state_dict = load_state_dict["state"]
|
||||
group = {"param_groups": load_state_dict["param_groups"]}
|
||||
check_state_dict_equal(optimizer_state_dict, state_dict)
|
||||
check_state_dict_equal(group_dict, group)
|
||||
check_state_dict_equal(load_state_dict, optimizer_state_dict)
|
||||
|
||||
optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, optimizer_state_dict["state"])
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
|
||||
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
|
||||
|
||||
model_state_dict = {
|
||||
"module.weight0": torch.rand((1024, 1024)),
|
||||
@ -118,10 +134,20 @@ def test_save_load():
|
||||
}
|
||||
model_saved_path = f"{tempdir}/save_model.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, model_state_dict)
|
||||
save(f_writer, model_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
|
||||
load_state_dict = load_flat(model_saved_path)
|
||||
load_state_dict = load_file(model_saved_path)
|
||||
check_state_dict_equal(model_state_dict, load_state_dict)
|
||||
|
||||
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
|
||||
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
|
||||
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
load_state_dict = load_file(model_saved_path)
|
||||
check_state_dict_equal(model_state_dict, load_state_dict)
|
||||
|
@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
|
||||
self.is_interleave = False
|
||||
self.num_layers_per_stage = None
|
||||
self.num_model_chunks = 1
|
||||
self.use_zbv = False
|
||||
|
||||
@property
|
||||
def num_stages(self):
|
||||
|
@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
|
||||
self.is_interleave = False
|
||||
self.num_layers_per_stage = None
|
||||
self.num_model_chunks = 1
|
||||
self.use_zbv = False
|
||||
|
||||
@property
|
||||
def num_stages(self):
|
||||
|
1085
tests/test_pipeline/test_schedule/test_zerobubble_pp.py
Normal file
1085
tests/test_pipeline/test_schedule/test_zerobubble_pp.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -8,7 +8,8 @@ from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
|
||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_base = LinearWithGradAccum.from_native_module(
|
||||
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
|
||||
)
|
||||
assert linear_base.weight.shape == torch.Size([128, 32])
|
||||
assert linear_base.bias.shape == torch.Size([128])
|
||||
assert linear_copy.weight is linear_base.weight
|
||||
assert linear_copy.bias is linear_base.bias
|
||||
|
||||
linear.load_state_dict(linear_base.state_dict())
|
||||
linear_base.load_state_dict(linear.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
# [batch_size, seq_len, hidden_size]
|
||||
x = torch.rand(2, 4, 32).cuda()
|
||||
x_for_unshard = x.expand_as(x.clone())
|
||||
x_for_unshard.requires_grad_(True)
|
||||
x_for_shard = x.expand_as(x.clone())
|
||||
x_for_shard.requires_grad_(True)
|
||||
|
||||
# run forward
|
||||
out = linear(x_for_unshard)
|
||||
gather_out = linear_base(x_for_shard)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
assert_close(linear.weight.grad, linear_base.weight.grad)
|
||||
# check the input gradients
|
||||
assert x_for_shard.grad is not None
|
||||
assert x_for_unshard.grad is not None
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_base = LinearWithGradAccum.from_native_module(
|
||||
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
|
||||
)
|
||||
assert linear_base.weight.shape == torch.Size([128, 32])
|
||||
assert linear_base.bias.shape == torch.Size([128])
|
||||
assert linear_copy.weight is linear_base.weight
|
||||
assert linear_copy.bias is linear_base.bias
|
||||
|
||||
linear.load_state_dict(linear_base.state_dict())
|
||||
linear_base.load_state_dict(linear.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
# [batch_size, seq_len, hidden_size]
|
||||
x = torch.rand(2, 4, 32).cuda()
|
||||
x_for_unshard = x.expand_as(x.clone())
|
||||
x_for_unshard.requires_grad_(True)
|
||||
x_for_shard = x.expand_as(x.clone())
|
||||
x_for_shard.requires_grad_(True)
|
||||
|
||||
# run forward
|
||||
out = linear(x_for_unshard)
|
||||
gather_out = linear_base(x_for_shard)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
# Weight grad is None before we do WeightGradStore pop
|
||||
assert linear_base.weight.grad is None
|
||||
# after WeightGradStore pop (dw computation complete), we assert weight grad
|
||||
WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
|
||||
WeightGradStore.pop(chunk=0)
|
||||
assert_close(linear.weight.grad, linear_base.weight.grad)
|
||||
|
||||
# check the input gradients
|
||||
assert x_for_shard.grad is not None
|
||||
assert x_for_unshard.grad is not None
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
|
||||
check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
|
||||
check_linear_1d_row(lazy_init, seq_parallel_mode)
|
||||
check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
|
||||
check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
|
||||
check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
|
||||
|
||||
|
||||
def check_dist_linear(rank, world_size, port):
|
||||
|
@ -310,8 +310,16 @@ def check_output_hidden_state(
|
||||
):
|
||||
org_hidden_state = org_output.last_hidden_state
|
||||
|
||||
if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
|
||||
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
|
||||
if stage_manager:
|
||||
if stage_manager.use_zbv:
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
|
||||
else:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
|
||||
else:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
else:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
|
||||
@ -388,7 +396,6 @@ def get_grad_tensors_for_check(
|
||||
pass
|
||||
if verbose and dist.get_rank() == 0:
|
||||
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
|
||||
|
||||
grad_to_check[suffix] = {
|
||||
"org_grad": org_grad.float(),
|
||||
"shard_grad": shard_grad.float(),
|
||||
|
@ -7,6 +7,7 @@ from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
|
||||
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
)
|
||||
if enable_gradient_checkpointing:
|
||||
# org_model.gradient_checkpointing_enable()
|
||||
sharded_model.unwrap().gradient_checkpointing_enable()
|
||||
sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
@ -112,12 +113,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
|
||||
check_flag = False
|
||||
if (
|
||||
(stage_manager is None)
|
||||
or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True))
|
||||
or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True))
|
||||
):
|
||||
check_flag = True
|
||||
if check_flag:
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ == "LlamaModel":
|
||||
check_output_hidden_state(
|
||||
org_output,
|
||||
@ -274,6 +281,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
)
|
||||
def run_llama_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
|
||||
if test_config.get("pp_style", None) == "zbv":
|
||||
mem_f = 34 * 32 + 5 * 4 * 16
|
||||
mem_w = -32 * 32
|
||||
mem_b = -mem_w - mem_f
|
||||
scheduler_nodes = PipelineGraph(
|
||||
n_stage=test_config["pp_size"],
|
||||
n_micro=test_config["num_microbatches"],
|
||||
f_cost=1000,
|
||||
b_cost=1000,
|
||||
w_cost=1000,
|
||||
c_cost=1,
|
||||
f_mem=mem_f,
|
||||
b_mem=mem_b,
|
||||
w_mem=mem_w,
|
||||
).get_v_schedule()
|
||||
test_config["scheduler_nodes"] = scheduler_nodes
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
|
||||
continue
|
||||
|
Loading…
Reference in New Issue
Block a user