Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-14 23:55:32 +00:00)
Merge pull request #6065 from duanjunwen/dev/zero_bubble
[Feat] Support zero bubble with shardformer input

Commit 8501202a35
@@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        custom_policy: Policy = None,
        pp_style: str = "1f1b",
        num_model_chunks: int = 1,
+       scheduler_nodes: List = None,
        num_layers_per_stage: Optional[List[int]] = None,
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,
@@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
-           assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-           assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+           assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+           assert (
+               pp_style == "interleaved" or pp_style == "zbv"
+           ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=self.pp_axis,
-               enable_interleave=pp_style == "interleaved",
+               enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
            )
@@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                )
+           elif pp_style == "zbv":
+               self.schedule = ZeroBubbleVPipeScheduler(
+                   schedule=scheduler_nodes,
+                   stage_manager=self.stage_manager,
+                   num_model_chunks=num_model_chunks,
+                   num_microbatch=num_microbatches,
+                   overlap_p2p=overlap_p2p,
+               )
            else:
                raise NotImplementedError()
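Note on usage: with the new "zbv" branch, the plugin can be constructed from a precomputed zero-bubble V schedule. A hedged sketch, assuming the import paths and PipelineGraph arguments that appear (commented out) in the test file later in this PR; the numeric cost/memory values below are placeholders, not values from this commit.

# Sketch only: argument values are illustrative assumptions.
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Build a V-shaped zero-bubble schedule for 4 stages x 4 microbatches.
graph = PipelineGraph(
    n_stage=4, n_micro=4,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,
    f_mem=1, b_mem=1, w_mem=1,
)
scheduler_nodes = graph.get_v_schedule()

# Hand the schedule to the plugin via the new pp_style="zbv" / scheduler_nodes kwargs.
plugin = MoeHybridParallelPlugin(
    tp_size=1, ep_size=1, pp_size=4,
    pp_style="zbv",
    num_model_chunks=2,
    num_microbatches=4,
    scheduler_nodes=scheduler_nodes,
    zero_stage=1,
    precision="bf16",
)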
@@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
        x.retain_grad()


+def require_grad(x: Any) -> None:
+    """Call require_grad on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+    """
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
+        x.requires_grad_()
+
+
 def detach(x: Any) -> Any:
     """Call detach() on a tensor.
@@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
     return x


+def clone(x: Any) -> Any:
+    """Call clone() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The cloned object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.clone()
+    return x
+
+
+def release_tensor_data(x: Any) -> Any:
+    """Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The deallocate .data object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.data.untyped_storage().resize_(0)
+    return x
+
+
 def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
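A standalone illustration (not code from this PR) of what the new release_tensor_data helper buys: untyped_storage().resize_(0) frees the tensor's data buffer while the Python object and its grad_fn stay alive, so backward can still run through an output whose payload has already been dropped.

import torch

x = torch.randn(4, 4, requires_grad=True)
y = (x * 2).sum()

# Release y's data but keep its autograd graph, as release_tensor_data does.
y.data.untyped_storage().resize_(0)
assert y.grad_fn is not None            # the graph survives
assert y.untyped_storage().size() == 0  # the data buffer is gone

# Backward only needs grad_fn here, not y's data.
torch.autograd.backward(y, grad_tensors=torch.ones_like(y))
assert x.grad is not None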
@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 import torch.cuda
 from torch.nn import Module, ModuleList
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map

 from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
@@ -12,7 +12,18 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager

-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
+from ._utils import (
+    clone,
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    release_tensor_data,
+    require_grad,
+    retain_grad,
+    to_device,
+)
 from .base import PipelineSchedule

 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
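tree_flatten is pulled in because the scheduler now handles dict-shaped inputs/outputs from shardformer models instead of bare tensors; flattening turns a nested container into the flat list of tensor leaves that backward_by_grad consumes. A tiny standalone illustration:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

output_obj = {"hidden_states": torch.randn(2, 4), "attentions": torch.randn(2, 4)}
leaves, spec = tree_flatten(output_obj)   # the tensor leaves, plus the dict structure
rebuilt = tree_unflatten(leaves, spec)    # same layout as output_obj
assert set(rebuilt.keys()) == set(output_obj.keys())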
@@ -24,21 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
            req.wait()


-def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if (out is None) or (not deallocate_pipeline_outputs):
-        return
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.untyped_storage().resize_(0)
-
-
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
@@ -409,6 +405,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        self,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
+       micro_batch: Optional[dict],
        input_obj: Optional[dict],
        criterion: Callable,
        accum_loss: Optional[torch.Tensor] = None,
@@ -427,18 +424,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
        """
        # Load input ids, attention mask and labels
-       # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
-       # for the first stage, input_obj is None
+       # for the first stage, input_obj is None; So,we use micro_batch as input_obj
        # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
        # Only attention_mask from micro_batch is used

        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            # fwd calculate
-           output_obj = model_chunk[model_chunk_id](input_obj)
+           internal_inputs = {} if input_obj is None else input_obj
+           # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+           output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)

            # last layer in model
            if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-               loss = criterion(output_obj) / self.num_microbatch
+               loss = criterion(output_obj, micro_batch) / self.num_microbatch
                if accum_loss is not None:
                    accum_loss.add_(loss.detach())
                if outputs is not None:
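model_forward is imported from ._utils, but its body is not part of this hunk. Conceptually it has to merge the raw micro-batch (input ids, attention mask, labels) with the tensors received from the neighbouring stage and call the chunk with keyword arguments. The following is only an assumption about that behaviour, sketched under a hypothetical name; the real helper may differ.

from typing import Any, Dict, Optional

from torch.nn import Module


def model_forward_sketch(model: Module, micro_batch: Optional[Dict[str, Any]], internal_inputs: Dict[str, Any]):
    # Hypothetical: stage-internal tensors (e.g. hidden_states) extend/override the micro-batch keys.
    kwargs = {**(micro_batch or {}), **internal_inputs}
    return model(**kwargs)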
@@ -452,6 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        optimizer: OptimizerWrapper,
+       micro_batch: Optional[dict],
        input_obj: Optional[dict],
        output_obj: Union[dict, torch.Tensor],
        output_obj_grad: Optional[dict],
@@ -462,7 +460,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            model_chunk (ModuleList or Module): Model Chunk to be run;
            model_chunk_id (int): The current model chunk idx;
            optimizer (OptimizerWrapper): Optimizer to update the model
-           input_obj (Optional[dict]): x.
+           input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj)
            output_obj (Union[dict, torch.Tensor]): y.
            output_obj_grad (dict): dy.

@@ -471,20 +469,52 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        """
        # calculate bwd b step ; only dx = w*dy;

-       # Retain the grad on the input_obj.
-       tree_map(retain_grad, input_obj)
+       # Retain the grad on the input_obj. No need retain_grad microbatch
+       if input_obj is not None:
+           tree_map(retain_grad, input_obj)

-       if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-           # loss backward; output_obj is loss; so output_obj_grad should be None
+       # x, y, dy list for backward_by_grad; Type: list[tensor];
+       input_obj_ = []
+       output_obj_ = []
+       output_obj_grad_ = []
+
+       # For chunk 0 stage 0, use micro_batch as input_obj_
+       if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+           input_obj_, _ = tree_flatten(micro_batch)
+           output_obj_, _ = tree_flatten(output_obj)  # y
+           output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
+       # For loss backward; output_obj is loss; output_obj_grad should be None
+       elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
            assert output_obj_grad is None
+           input_obj_, _ = tree_flatten(input_obj)
+           output_obj_.append(output_obj)  # LOSS
+           output_obj_grad_.append(output_obj_grad)  # None
+
+       # For other chunk stage, use input_obj as input_obj_;
+       else:
+           input_obj_, _ = tree_flatten(input_obj)
+           output_obj_, _ = tree_flatten(output_obj)  # y
+           output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

        optimizer.backward_by_grad(
-           tensor=output_obj,
-           grad=output_obj_grad,
-           inputs=input_obj,
+           tensor=output_obj_,
+           grad=output_obj_grad_,
+           inputs=input_obj_,
            retain_graph=True,
        )
-       return input_obj.grad
+
+       # Format output_obj_grad
+       input_obj_grad = {}
+       if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+           for k, v in micro_batch.items():
+               if isinstance(v, torch.Tensor) and v.grad is not None:
+                   input_obj_grad[k] = v.grad
+       else:
+           for k, v in input_obj.items():
+               if isinstance(v, torch.Tensor) and v.grad is not None:
+                   input_obj_grad[k] = v.grad
+       return input_obj_grad

    def backward_w_step(
        self,
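The split above is the core of the zero-bubble schedule: backward_b_step asks only for gradients of the stage inputs (dx = w*dy) and keeps the graph alive, while backward_w_step later reuses that same graph for the parameter gradients (dw = x*dy). A minimal standalone illustration of the same split using plain autograd calls (not the OptimizerWrapper.backward_by_grad API):

import torch

w = torch.nn.Parameter(torch.randn(4, 4))
x = torch.randn(2, 4, requires_grad=True)   # activation received from the previous stage
y = x @ w                                    # stage output
dy = torch.randn_like(y)                     # gradient received from the next stage

# B step: dx only; retain_graph so the same graph can serve the W step later.
(dx,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=dy, retain_graph=True)

# W step: dw only; restrict accumulation to the parameters and free the graph.
torch.autograd.backward(y, grad_tensors=dy, inputs=[w], retain_graph=False)

assert w.grad is not None   # dw accumulated on the parameter
assert x.grad is None       # dx was returned explicitly, not accumulated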
@@ -508,12 +538,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        """
        # calculate bwd w step ; only dw = x*dy;

+       # y, dy list for w backward_by_grad; Type: list[tensor];
+       output_obj_ = []
+       output_obj_grad_ = []
+
        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-           # loss backward; output_obj is loss
-           output_obj_grad = None
+           # loss backward; output_obj is loss;
+           output_obj_.append(output_obj)  # LOSS
+           output_obj_grad_.append(None)  # None
+       else:
+           output_obj_, _ = tree_flatten(output_obj)  # y
+           output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

        optimizer.backward_by_grad(
-           tensor=output_obj,
-           grad=output_obj_grad,
+           tensor=output_obj_,
+           grad=output_obj_grad_,
            inputs=list(model_chunk[model_chunk_id].parameters()),
            retain_graph=False,
        )
@@ -543,9 +582,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
        # Step1: recv fwd
        if model_chunk_id == 0:
-           # is first stage; get input from func param
+           # is first stage; get input from microbatch
            if self.stage_manager.is_first_stage(ignore_chunk=True):
-               input_obj = micro_batch
+               input_obj = None
            else:
                input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
        else:
@@ -557,55 +596,75 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)

        # Here, let input_obj.requires_grad_()
-       tree_map(torch.Tensor.requires_grad_, input_obj)
+       # if input_obj is not None:
+       if not isinstance(input_obj, torch.Tensor):
+           tree_map(require_grad, input_obj)
+
+       # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
+       # tree_map(torch.Tensor.requires_grad_, micro_batch)

        # Step2: fwd step
        output_obj = self.forward_step(
            model_chunk=model_chunk,
            model_chunk_id=model_chunk_id,
+           micro_batch=micro_batch,
            input_obj=input_obj,
            criterion=criterion,
            accum_loss=accum_loss,
            outputs=outputs,
        )

+       # Step3:
+       # 3-1:detach output; detach output for send fwd;
        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
            # We should not detach bwd LOSS
            pass
        else:
-           detached_output_obj = output_obj.clone().detach()
+           # detach output
+           detached_output_obj = tree_map(detach, output_obj)
+           # 3-2 clone detached_output_obj
+           detached_output_obj = tree_map(clone, detached_output_obj)
+
+       # 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
+       if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+           # We should not release_tensor_data bwd LOSS
+           pass
+       else:
+           # release_tensor_data output
+           tree_map(release_tensor_data, output_obj)
+
+       # add input and output object for backward b
+       self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+
+       # for bwd b&w, we only need the graph(grad_fn) of output_obj
+       # Do not release_tensor_data loss, release_tensor_data other output_obj;
+       if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+           self.output_tensors[model_chunk_id].append(output_obj)
+           self.output_tensors_dw[model_chunk_id].append(output_obj)
+       else:
+           self.output_tensors[model_chunk_id].append(output_obj)
+           self.output_tensors_dw[model_chunk_id].append(output_obj)

-       # Step3: send fwd
        # add output to send_fwd_buffer
-       if model_chunk_id == 0:
+       if model_chunk_id == 0:  # chunk 0
            # is last stage; send to local_send_forward_buffer
            if self.stage_manager.is_last_stage(ignore_chunk=True):
                self.local_send_forward_buffer.append(detached_output_obj)
            else:
                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
-       else:
-           # is first stage; end of fwd; append LOSS to local_send_backward_buffer
+       else:  # chunk 1
+           # is first stage; end of fwd; do nothing
            if self.stage_manager.is_first_stage(ignore_chunk=True):
                pass
            else:
                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)

-       # add input and output object for backward b
-       self.input_tensors[model_chunk_id].append(input_obj)
-       # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-       deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
-       self.output_tensors[model_chunk_id].append(output_obj)
-       # add output object for backward w
-       self.output_tensors_dw[model_chunk_id].append(output_obj)

    def schedule_b(
        self,
        scheduled_node,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        optimizer: OptimizerWrapper,
-       # input_obj: Optional[dict],
-       # output_obj: Union[dict, torch.Tensor],
-       # output_obj_grad: Optional[dict],
    ):
        """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;

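The detach-then-clone order in steps 3-1/3-2 matters because detach alone returns a view that shares storage with output_obj; once step 3-3 releases that storage, an un-cloned copy in the send buffer would be emptied too. A small standalone check of that sharing behaviour (illustration only):

import torch

out = torch.randn(3, requires_grad=True) * 2   # non-leaf tensor with a grad_fn
shared = out.detach()                           # shares storage with out
owned = out.detach().clone()                    # owns its own storage

assert shared.untyped_storage().data_ptr() == out.untyped_storage().data_ptr()
assert owned.untyped_storage().data_ptr() != out.untyped_storage().data_ptr()

out.data.untyped_storage().resize_(0)           # what release_tensor_data does
assert shared.untyped_storage().size() == 0     # the un-cloned view lost its data too
assert owned.numel() == 3 and owned.untyped_storage().size() > 0  # the clone is unaffected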
@@ -616,25 +675,24 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        Returns:
            Nothing.
        """

        # Step1: recv bwd
        if model_chunk_id == 0:
            # chunk0 is last stage; recv output_grad from local_send_backward_buffer
            if self.stage_manager.is_last_stage(ignore_chunk=True):
                output_tensor_grad = self.local_send_backward_buffer.pop(0)
-           # chunk 0 not last stage; recv output_grad from recv_backward_buffer
+           # chunk0 not last stage; recv output_grad from recv_backward_buffer
            else:
                output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
        else:
            # chunk1, is first stage; recv LOSS from local send bwd buffer
            if self.stage_manager.is_first_stage(ignore_chunk=True):
                output_tensor_grad = None
            # chunk1, not first stage; recv output_grad from recv_backward_buffer
            else:
                output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)

        # get input and output object from buffer;
-       input_obj = self.input_tensors[model_chunk_id].pop(0)
+       micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
        output_obj = self.output_tensors[model_chunk_id].pop(0)

        # save output_tensor_grad for dw
@@ -645,12 +703,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            # we save output_tensor_grad here
            self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)

-       # _wait_p2p(recv_bwd_handles)
        # Step2: bwd step
        input_object_grad = self.backward_b_step(
            model_chunk=model_chunk,
            model_chunk_id=model_chunk_id,
            optimizer=optimizer,
+           micro_batch=micro_batch,
            input_obj=input_obj,
            output_obj=output_obj,
            output_obj_grad=output_tensor_grad,
@@ -777,8 +835,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            # communication
            communication_func = self.communication_map[scheduled_node.type]
            communication_func(scheduled_node.chunk)
-
-       if scheduled_node.type == "F":
+       elif scheduled_node.type == "F":
            self.schedule_f(
                scheduled_node=scheduled_node,
                model_chunk=model_chunk,
@@ -1,4 +1,6 @@
 from copy import deepcopy
+from functools import partial
+from types import MethodType
 from typing import Tuple

 import pytest
@@ -14,6 +16,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo

@@ -23,10 +26,32 @@ class MlpModel(nn.Module):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])

-   def forward(self, x):
+   def forward(
+       self,
+       hidden_states,
+   ):
        for layer in self.layers:
-           x = layer(x)
-       return x
+           hidden_states = layer(hidden_states)
+       return hidden_states
+
+
+def pp_linear_fwd(
+    forward,
+    data: torch.Tensor = None,
+    hidden_states: torch.Tensor = None,
+    stage_mgr: PipelineStageManager = None,
+    model_chunk_id: int = None,
+):
+    with stage_mgr.switch_model_chunk_id(model_chunk_id):
+        # fwd end
+        if stage_mgr.is_first_stage() and model_chunk_id == 1:
+            return forward(hidden_states)
+        # fwd start
+        elif stage_mgr.is_first_stage() and model_chunk_id == 0:
+            return {"hidden_states": forward(data)}
+        # fwd middle
+        else:
+            return {"hidden_states": forward(hidden_states)}


 def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
@@ -561,19 +586,24 @@ def run_fwd_bwd_vschedule_with_optim(test_config):

    # init loss func
    def criterion(x, *args, **kwargs):
+       x = x["hidden_states"]
+       return (x * x).mean()
+
+   def criterion_base(x, *args, **kwargs):
        return (x * x).mean()

    # init model and input
    batch_size = test_config["batch_size"]
    num_layers = 8
    assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
-   in_dim = out_dim = 4096
+   in_dim = out_dim = 1024
    before_init_memory = torch.cuda.memory_allocated() / 1024**3
    print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
-   data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
-   input_base = [t.clone() for t in data_iter]
+   # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+   data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
+   # input_base = [t.clone() for t in data_iter]
+   input_base = {k: v.clone() for k, v in data_iter.items()}
    model_base = deepcopy(model)

    if rank == 0:
@@ -581,24 +611,44 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
        local_chunk = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 0 or idx == 7:
+               sub_model._forward = sub_model.forward
+               sub_model.forward = MethodType(
+                   partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                   sub_model._forward,
+               )
                local_chunk.append(sub_model)
    elif rank == 1:
        # layer 1 & 6 to chunk 1 on rank1
        local_chunk = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 1 or idx == 6:
+               sub_model._forward = sub_model.forward
+               sub_model.forward = MethodType(
+                   partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                   sub_model._forward,
+               )
                local_chunk.append(sub_model)
    elif rank == 2:
        # layer 2 & 5 to chunk 2 on rank2
        local_chunk = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 2 or idx == 5:
+               sub_model._forward = sub_model.forward
+               sub_model.forward = MethodType(
+                   partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                   sub_model._forward,
+               )
                local_chunk.append(sub_model)
    else:
        # layer 3 & 4 to chunk 3 on rank3
        local_chunk = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 3 or idx == 4:
+               sub_model._forward = sub_model.forward
+               sub_model.forward = MethodType(
+                   partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                   sub_model._forward,
+               )
                local_chunk.append(sub_model)

    # init optimizer
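The test wires pp_linear_fwd onto each selected layer by stashing the original bound method on sub_model._forward and rebinding forward with MethodType(partial(...), sub_model._forward), so the wrapper receives the old forward as its first argument and the stage metadata as frozen keywords. A stripped-down standalone version of the same trick:

from functools import partial
from types import MethodType

import torch
import torch.nn as nn


def wrapped_fwd(original_forward, x, tag=None):
    # original_forward is the layer's old bound forward, injected by MethodType.
    print(f"forward of layer tagged {tag}")
    return original_forward(x)


layer = nn.Linear(4, 4)
layer._forward = layer.forward
layer.forward = MethodType(partial(wrapped_fwd, tag="chunk-0"), layer._forward)

out = layer.forward(torch.randn(2, 4))  # prints the tag, then runs the original Linear forward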
@@ -611,7 +661,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
    torch.cuda.synchronize()
    result = scheduler.forward_backward_step(
        model_chunk=local_chunk,
-       data_iter=iter(data_iter),
+       data_iter=iter([data_iter]),
        criterion=criterion,
        optimizer=optimizer_pp,
        return_loss=True,
|
|||||||
|
|
||||||
# assert memory
|
# assert memory
|
||||||
if rank != 0:
|
if rank != 0:
|
||||||
# w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
# w.grad: hid_dim * hid_dim * microbatch * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
||||||
# output hid_dim * hid_dim * 4(fp32) / 1024**3
|
# output: hid_dim * hid_dim * microbatch * 4(fp32) / 1024**3
|
||||||
# optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
# optim: state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
||||||
print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
|
print(
|
||||||
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
|
f" num_microbatch {num_microbatch} rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 * batch_size / 1024**3)}"
|
||||||
|
)
|
||||||
|
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 * batch_size / 1024**3)
|
||||||
else:
|
else:
|
||||||
# rank0 will also hold output;
|
# rank0 will also hold output;
|
||||||
print(
|
print(
|
||||||
f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
|
f" num_microbatch {num_microbatch} rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
|
||||||
)
|
)
|
||||||
assert round((after_pp_step_memory - after_init_memory), 5) <= round(
|
assert round((after_pp_step_memory - after_init_memory), 5) <= round(
|
||||||
(in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
|
(in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
|
||||||
)
|
)
|
||||||
|
|
||||||
##########################
|
##########################
|
||||||
# Fwd bwd for base
|
# Fwd bwd for base
|
||||||
##########################
|
##########################
|
||||||
# fwd & bwd
|
# fwd & bwd
|
||||||
output_base = model_base(input_base[0])
|
output_base = model_base(input_base["data"])
|
||||||
loss_base = criterion(output_base)
|
loss_base = criterion_base(output_base)
|
||||||
loss_base.backward()
|
loss_base.backward()
|
||||||
optimizer_base.step()
|
optimizer_base.step()
|
||||||
|
|
||||||
@ -653,7 +705,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||||||
# only chunk 1 stage 0 hold loss and output
|
# only chunk 1 stage 0 hold loss and output
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
assert_close(result["loss"], loss_base)
|
assert_close(result["loss"], loss_base)
|
||||||
assert_close(result["outputs"], output_base)
|
assert_close(result["outputs"]["hidden_states"], output_base)
|
||||||
|
|
||||||
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
|
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
|
||||||
##########################
|
##########################
|
||||||
@@ -724,28 +776,108 @@ def run_with_hybridplugin(test_config):
    "test_config",
    [
        {
-           "batch_size": 8,
+           "pp_style": "zbv",
            "tp_size": 1,
+           "ep_size": 1,
            "pp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
-           "num_model_chunk": 2,
+           "num_model_chunks": 2,
        },
    ],
 )
 def run_with_moehybridplugin(test_config):
-   model_zoo.get_sub_registry("transformers_bert")
-   test_config["use_lazy_init"] = False
+   sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+   # test_config["use_lazy_init"] = False
    test_config["initial_scale"] = 2**16
    model_list = [
        "transformers_bert",
    ]
+   clear_layout_converter()
+
+   for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+       if name in model_list:
+           # base param
+           model = model_fn()
+           data = data_gen_fn()
+           print(f"data {data}")
+           criterion = loss_fn
+           optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
+
+           output = model(**data)
+           loss = criterion(output)
+           loss.backward()
+           optimizer.step()
+           print(f"output {output}")
+
+           # # pp param
+           # model_pp = deepcopy(model)
+           # data_pp = deepcopy(data)
+           # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
+
+           # # init pipeline graph
+           # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
+           # mem_f = 34 * h + 5 * a * s
+           # mem_w = -32 * h
+           # mem_b = -mem_w - mem_f
+           # graph = PipelineGraph(
+           #     n_stage=test_config["pp_size"],
+           #     n_micro=test_config["num_microbatches"],
+           #     f_cost=1,
+           #     b_cost=1,
+           #     w_cost=1,
+           #     c_cost=1,
+           #     f_mem=mem_f,
+           #     b_mem=mem_b,
+           #     w_mem=mem_w,
+           #     # max_mem=mem_f * (p * 2 + m_offset),
+           # )
+
+           # zbv_schedule = graph.get_v_schedule()
+
+           # test_config["scheduler_nodes"] = zbv_schedule
+           # plugin = MoeHybridParallelPlugin(
+           #     **test_config
+           # )
+           # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
+           #     model = model_pp,
+           #     optimizer = optimizer_pp,
+           #     criterion = criterion,
+           #     dataloader = data_pp,
+           # )
+
+           # output_pp = plugin.execute_pipeline(
+           #     data_iter=iter(data),
+           #     model=model,
+           #     criterion=criterion,
+           #     optimizer=optimizer,
+           #     return_loss = True,
+           #     return_outputs = True,
+           # )


 # TODO:6) support booster & Hybrid base 4)


 # TODO:7) support booster & MoEHybrid base 4)
+@parameterize(
+    "test_config",
+    [
+        {
+            "pp_style": "zbv",
+            "tp_size": 1,
+            "ep_size": 1,
+            "pp_size": 4,
+            "num_microbatches": 4,
+            "zero_stage": 1,
+            "precision": "bf16",
+            "num_model_chunks": 2,
+        },
+    ],
+)
+def run_with_booster_moehybridplugin(test_config):
+    pass


 def run_dist(rank, world_size, port):
@@ -754,6 +886,7 @@ def run_dist(rank, world_size, port):
    # run_fwd_bwd_iter_input()
    run_fwd_bwd_vschedule_with_optim()
    # run_with_moehybridplugin()
+   # run_with_booster_moehybridplugin()


 @pytest.mark.dist