mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-04 11:06:25 +00:00
[feat] add fwd_bwd_step, run_fwd_only;
This commit is contained in:
parent
48ba22dbfd
commit
6af81d8c0d
@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
@ -696,6 +696,54 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
output_obj_grad=output_obj_grad,
|
||||
)
|
||||
|
||||
def run_forward_only(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> Dict:
|
||||
assert self.forward_only
|
||||
|
||||
# prepare batch
|
||||
self.load_batch(data_iter)
|
||||
|
||||
# prepare accum loss & output
|
||||
accum_loss = None
|
||||
|
||||
# reset accum loss at fwd end;
|
||||
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
|
||||
|
||||
it = 0
|
||||
# while we still have schedules_node in self.schedules
|
||||
while it < len(self.schedules):
|
||||
scheduled_node = self.schedules[it]
|
||||
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
if scheduled_node.type == "RECV_FORWARD":
|
||||
self.recv_forward(scheduled_node.chunk)
|
||||
elif scheduled_node.type == "SEND_FORWARD":
|
||||
self.send_forward(scheduled_node.chunk)
|
||||
if scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
criterion=criterion,
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
it += 1
|
||||
# return loss & output
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def run_forward_backward(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
@ -704,7 +752,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
):
|
||||
) -> Dict:
|
||||
"""
|
||||
Runs Zerobubble schedule, with communication between pipeline stages.
|
||||
"""
|
||||
@ -770,3 +818,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
|
||||
data_iter (Iterable): Data iterator.
|
||||
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
|
||||
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
|
||||
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
|
||||
|
||||
Returns:
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
"""
|
||||
self.forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||
|
||||
if self.forward_only:
|
||||
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
|
||||
else:
|
||||
result = self.run_forward_backward(
|
||||
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
||||
return result
|
||||
|
@ -644,10 +644,10 @@ def run_fwd_bwd_vschedule_with_optim(
|
||||
graph = PipelineGraph(
|
||||
n_stage=world_size,
|
||||
n_micro=num_microbatch,
|
||||
f_cost=6,
|
||||
b_cost=6,
|
||||
w_cost=6,
|
||||
c_cost=6,
|
||||
f_cost=1,
|
||||
b_cost=1,
|
||||
w_cost=1,
|
||||
c_cost=1,
|
||||
f_mem=mem_f,
|
||||
b_mem=mem_b,
|
||||
w_mem=mem_w,
|
||||
@ -714,7 +714,7 @@ def run_fwd_bwd_vschedule_with_optim(
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
result = scheduler.run_forward_backward(
|
||||
result = scheduler.forward_backward_step(
|
||||
model_chunk=local_chunk,
|
||||
data_iter=iter(data_iter),
|
||||
criterion=criterion,
|
||||
@ -793,6 +793,25 @@ def run_fwd_bwd_vschedule_with_optim(
|
||||
assert val_base[:2] == val_pp
|
||||
|
||||
|
||||
# 4) support Hybrid base 3)
|
||||
def run_with_hybrid(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
num_microbatch: int,
|
||||
batch_size: int,
|
||||
num_model_chunk: int,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# 5) support MoE base 3)
|
||||
|
||||
# 6) support booster & Hybrid base 4)
|
||||
|
||||
# 6) support booster & MoE base 4)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("num_microbatch", [4])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
|
Loading…
Reference in New Issue
Block a user