https://github.com/hpcaitech/ColossalAI.git
[feat] add fwd_bwd_step, run_fwd_only;
commit 6af81d8c0d (parent 48ba22dbfd)
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.cuda
@@ -696,6 +696,54 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_grad=output_obj_grad,
         )
 
+    def run_forward_only(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        assert self.forward_only
+
+        # prepare batch
+        self.load_batch(data_iter)
+
+        # prepare accum loss & output
+        accum_loss = None
+
+        # reset accum loss at fwd end;
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
+        it = 0
+        # while we still have schedules_node in self.schedules
+        while it < len(self.schedules):
+            scheduled_node = self.schedules[it]
+
+            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+                # communication
+                if scheduled_node.type == "RECV_FORWARD":
+                    self.recv_forward(scheduled_node.chunk)
+                elif scheduled_node.type == "SEND_FORWARD":
+                    self.send_forward(scheduled_node.chunk)
+            if scheduled_node.type == "F":
+                self.schedule_f(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    criterion=criterion,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
+                )
+            it += 1
+        # return loss & output
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
+
     def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
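
The new run_forward_only walks the precomputed node list in self.schedules and executes only the forward-direction communication and "F" compute nodes, skipping every backward ("B") and weight-grad ("W") node. Below is a minimal, self-contained sketch of that dispatch pattern; ScheduledNode and the print-based handlers are simplified stand-ins for the scheduler internals, not ColossalAI's API.

# Minimal sketch of the schedule-node dispatch loop (stand-in types,
# not ColossalAI's real classes).
from dataclasses import dataclass

AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "SEND_FORWARD", "RECV_BACKWARD", "SEND_BACKWARD"}

@dataclass
class ScheduledNode:
    type: str   # "RECV_FORWARD", "SEND_FORWARD", "F", "B", "W", ...
    chunk: int  # which model chunk on this stage the node applies to

def run_forward_only(schedules: list) -> None:
    for node in schedules:
        if node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
            # forward-only mode issues just the forward-direction communication
            if node.type == "RECV_FORWARD":
                print(f"recv activations for chunk {node.chunk}")
            elif node.type == "SEND_FORWARD":
                print(f"send activations for chunk {node.chunk}")
        if node.type == "F":
            print(f"run forward compute on chunk {node.chunk}")
        # "B" and "W" nodes are simply skipped when no grad is needed

run_forward_only([ScheduledNode("RECV_FORWARD", 0), ScheduledNode("F", 0), ScheduledNode("SEND_FORWARD", 0)])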
@@ -704,7 +752,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ):
+    ) -> Dict:
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
@@ -770,3 +818,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if outputs is not None:
             outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model_chunk (ModuleList or Module): Model chunk to be trained. The original interleaved schedule uses a module list, whereas shardformer uses the entire model plus a layer specification.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+        Returns:
+            dict: A dict with keys 'loss' and 'outputs'.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(
+                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
+            )
+
+        return result
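
Because forward_backward_step derives self.forward_only from torch.is_grad_enabled(), callers select the execution path through the grad context rather than an explicit flag. A hedged usage sketch follows; scheduler, model_chunk, loader, criterion, and optimizer are placeholders, not objects from this diff.

import torch

def train_step(scheduler, model_chunk, loader, criterion, optimizer):
    # grad enabled -> dispatches to run_forward_backward
    return scheduler.forward_backward_step(
        model_chunk, iter(loader), criterion, optimizer, return_loss=True
    )

@torch.no_grad()
def eval_step(scheduler, model_chunk, loader, criterion):
    # grad disabled -> dispatches to run_forward_only; optimizer may stay None
    return scheduler.forward_backward_step(
        model_chunk, iter(loader), criterion, optimizer=None,
        return_loss=True, return_outputs=True
    )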
@@ -644,10 +644,10 @@ def run_fwd_bwd_vschedule_with_optim(
     graph = PipelineGraph(
         n_stage=world_size,
         n_micro=num_microbatch,
-        f_cost=6,
-        b_cost=6,
-        w_cost=6,
-        c_cost=6,
+        f_cost=1,
+        b_cost=1,
+        w_cost=1,
+        c_cost=1,
         f_mem=mem_f,
         b_mem=mem_b,
         w_mem=mem_w,
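
The test switches the per-pass costs from 6 to 1. Only the relative f/b/w/c ratios shape the solved zero-bubble schedule, so uniform unit costs keep the planned node order stable regardless of any measured kernel timings. A hedged construction sketch is below; the import path, the get_v_schedule() accessor, and the mem_* values are assumptions, not taken from this diff.

# Assumed import path and accessor; only the keyword arguments shown in
# the hunk above are confirmed by this commit.
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

mem_f, mem_b, mem_w = 34, -32, -2  # assumed per-microbatch memory deltas

graph = PipelineGraph(
    n_stage=4,   # world_size in the test
    n_micro=4,   # num_microbatch in the test
    f_cost=1,    # forward pass
    b_cost=1,    # backward pass (activation grad)
    w_cost=1,    # backward pass (weight grad)
    c_cost=1,    # communication
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()  # assumed: one node list per pipeline rank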
@@ -714,7 +714,7 @@ def run_fwd_bwd_vschedule_with_optim(
     )
 
     torch.cuda.synchronize()
-    result = scheduler.run_forward_backward(
+    result = scheduler.forward_backward_step(
         model_chunk=local_chunk,
         data_iter=iter(data_iter),
         criterion=criterion,
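
With the call site moved to the public forward_backward_step entry point, the test consumes the same {"loss": ..., "outputs": ...} dict that both internal paths return. A small sketch of that contract; check_result and the is_first_stage flag are placeholders for the real test fixtures.

def check_result(result: dict, is_first_stage: bool) -> None:
    # accum_loss is only materialized on the first stage (ignore_chunk=True)
    if is_first_stage:
        assert result["loss"] is not None
    else:
        assert result["loss"] is None
    # outputs stays None unless return_outputs=True on the first stage
    assert "outputs" in result

check_result({"loss": 0.5, "outputs": None}, is_first_stage=True)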
@@ -793,6 +793,25 @@ def run_fwd_bwd_vschedule_with_optim(
     assert val_base[:2] == val_pp
 
 
+# 4) support Hybrid base 3)
+def run_with_hybrid(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
+    pass
+
+
+# 5) support MoE base 3)
+
+# 6) support booster & Hybrid base 4)
+
+# 7) support booster & MoE base 4)
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("num_microbatch", [4])
 @pytest.mark.parametrize("batch_size", [4])