[feat] add fwd_bwd_step, run_fwd_only;

duanjunwen 2024-08-30 02:47:52 +00:00
parent 48ba22dbfd
commit 6af81d8c0d
2 changed files with 108 additions and 7 deletions


@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.cuda
@@ -696,6 +696,54 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_grad=output_obj_grad,
         )
+
+    def run_forward_only(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        assert self.forward_only
+
+        # prepare batch
+        self.load_batch(data_iter)
+
+        # prepare accum loss & outputs
+        accum_loss = None
+        # loss is accumulated on the first stage, where the V-schedule ends
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
+        it = 0
+        # run through every scheduled node in self.schedules
+        while it < len(self.schedules):
+            scheduled_node = self.schedules[it]
+            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+                # communication; only forward sends/receives are needed in forward-only mode
+                if scheduled_node.type == "RECV_FORWARD":
+                    self.recv_forward(scheduled_node.chunk)
+                elif scheduled_node.type == "SEND_FORWARD":
+                    self.send_forward(scheduled_node.chunk)
+            if scheduled_node.type == "F":
+                self.schedule_f(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    criterion=criterion,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
+                )
+            it += 1
+
+        # return loss & outputs
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
 
     def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
@@ -704,7 +752,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ):
+    ) -> Dict:
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
@@ -770,3 +818,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if outputs is not None:
             outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model_chunk (ModuleList or Module): Model chunk to be trained. The original interleaved schedule uses a module list, whereas shardformer uses the entire model plus a layer specification.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return the loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return the model outputs. Defaults to False.
+
+        Returns:
+            dict: A dict with keys 'loss' and 'outputs'.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(
+                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
+            )
+
+        return result
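
For reference, a minimal usage sketch of the new entry point. The dispatch on torch.is_grad_enabled() is taken from the code above; the surrounding names (scheduler, local_chunk, data_iter, optimizer, eval_iter) are illustrative placeholders, not part of this commit.

    # Training step: grad is enabled, so forward_backward_step dispatches to
    # run_forward_backward and an optimizer must be supplied.
    result = scheduler.forward_backward_step(
        model_chunk=local_chunk,
        data_iter=iter(data_iter),
        criterion=criterion,
        optimizer=optimizer,
        return_loss=True,
        return_outputs=True,
    )
    loss, outputs = result["loss"], result["outputs"]

    # Evaluation step: under no_grad, forward-only mode is detected
    # automatically, run_forward_only is used, and the optimizer may be omitted.
    with torch.no_grad():
        eval_result = scheduler.forward_backward_step(
            model_chunk=local_chunk,
            data_iter=iter(eval_iter),
            criterion=criterion,
            return_loss=True,
        )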


@@ -644,10 +644,10 @@ def run_fwd_bwd_vschedule_with_optim(
     graph = PipelineGraph(
         n_stage=world_size,
         n_micro=num_microbatch,
-        f_cost=6,
-        b_cost=6,
-        w_cost=6,
-        c_cost=6,
+        f_cost=1,
+        b_cost=1,
+        w_cost=1,
+        c_cost=1,
         f_mem=mem_f,
         b_mem=mem_b,
         w_mem=mem_w,
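
Context for the change above: the f/b/w/c costs are the relative forward, backward, weight-update, and communication timings that PipelineGraph uses when deriving the static zero-bubble schedule, so switching from 6 to 1 only rescales the cost model uniformly. A minimal sketch of how the graph is consumed, following the get_v_schedule() call used elsewhere in this test; selecting the per-rank node list is an assumption here:

    # Derive static zero-bubble schedules from the cost model above, then hand
    # this rank's node list to the scheduler (per-rank indexing is assumed).
    zbv_schedule = graph.get_v_schedule()
    local_schedule = zbv_schedule[rank]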
@@ -714,7 +714,7 @@ def run_fwd_bwd_vschedule_with_optim(
     )
 
     torch.cuda.synchronize()
-    result = scheduler.run_forward_backward(
+    result = scheduler.forward_backward_step(
         model_chunk=local_chunk,
         data_iter=iter(data_iter),
         criterion=criterion,
@@ -793,6 +793,25 @@ def run_fwd_bwd_vschedule_with_optim(
     assert val_base[:2] == val_pp
 
 
+# 4) support Hybrid, based on 3)
+def run_with_hybrid(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
+    pass
+
+
+# 5) support MoE, based on 3)
+
+# 6) support booster & Hybrid, based on 4)
+
+# 7) support booster & MoE, based on 4)
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("num_microbatch", [4])
 @pytest.mark.parametrize("batch_size", [4])