[feat] add run_fwd_bwd with microbatch support (replace input_obj with data_iter) & test; add p & p.grad assert_close test; all pass;

duanjunwen 2024-08-27 09:31:38 +00:00
parent 9e0bd1af00
commit 8b37323f16
2 changed files with 247 additions and 28 deletions


@@ -495,7 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: torch.nn.ModuleList,
         model_chunk_id: int,
-        input_obj: Optional[dict],
         criterion: Callable,
         accum_loss: Optional[torch.Tensor] = None,
         outputs: Optional[List[Any]] = None,
@@ -506,7 +505,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             scheduled_node:
             model_chunk (ModuleList or Module): Model Chunk to be run;
             model_chunk_id (int): The current model chunk idx;
-            input_obj (Optional[dict]): x;
             criterion (Callable): loss function;
             accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
             outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
@@ -518,7 +516,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if model_chunk_id == 0:
             # is first stage; get input from func param
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                input_obj = input_obj
+                input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
         else:
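With this change, the first stage pulls its input from the batch loaded by load_batch instead of a function argument. Below is a minimal sketch of what a load_micro_batch-style helper does; this is an assumption for illustration, not the ColossalAI implementation (name and signature are hypothetical): each microbatch is a contiguous slice of the full batch along dim 0.

import torch

def load_micro_batch_sketch(batch: torch.Tensor, microbatch_size: int, microbatch_id: int) -> torch.Tensor:
    # microbatch `microbatch_id` is a contiguous slice of the full batch along dim 0
    start = microbatch_id * microbatch_size
    return batch[start : start + microbatch_size]

# e.g. a batch of 4 samples split into 4 microbatches of size 1
batch = torch.rand(4, 8, 8)
assert load_micro_batch_sketch(batch, microbatch_size=1, microbatch_id=0).shape == (1, 8, 8)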
@@ -671,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
-        input_obj: Optional[dict],
         data_iter: Iterable,
         criterion: Callable[..., Any],
         optimizer: Optional[OptimizerWrapper] = None,
@@ -683,7 +680,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # # prepare batch
         self.load_batch(data_iter)
-        # print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
+        print(
+            f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
+        )

         it = 0
         # while we still have schedules_node in self.schedules
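The printed quantities obey a simple invariant. Under this test's settings (batch_size=4 and num_microbatch=4, with tensor shapes taken from the test below), the split works out to one sample per microbatch:

batch_size = 4
num_microbatch = 4
assert batch_size % num_microbatch == 0  # microbatch splitting assumes even division
microbatch_size = batch_size // num_microbatch  # -> 1
# a batch of shape (4, 8, 8) therefore yields microbatches of shape (1, 8, 8)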
@@ -707,7 +706,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
-                    input_obj=input_obj,
                     criterion=criterion,
                     accum_loss=return_loss,
                     outputs=return_outputs,


@@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
     return num_params, num_params_trainable


-# Test run_forward_backward with baseline;
-def test_run_fwd_bwd_base(
+# Test iter input & multiple microbatch
+def test_run_fwd_bwd_iter_input(
     rank: int,
     world_size: int,
     port: int,
@@ -47,7 +47,7 @@ def test_run_fwd_bwd_base(
    rank = dist.get_rank()
    pp_size = world_size
    pg_mesh = ProcessGroupMesh(pp_size)
+   num_microbatch = 4

    # stage_manager
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
@@ -55,6 +55,7 @@ def test_run_fwd_bwd_base(
    zbv_schedule = [
        # stage 0
        [
+           # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
@@ -73,9 +74,67 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
+           # microbatch 1
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1),
+           ScheduledNode(type="F", chunk=0, stage=0, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1),
+           ScheduledNode(type="F", chunk=1, stage=0, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1),
+           ScheduledNode(type="B", chunk=1, stage=0, minibatch=1),
+           ScheduledNode(type="W", chunk=1, stage=0, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1),
+           ScheduledNode(type="B", chunk=0, stage=0, minibatch=1),
+           ScheduledNode(type="W", chunk=0, stage=0, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
+           # microbatch 2
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2),
+           ScheduledNode(type="F", chunk=0, stage=0, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2),
+           ScheduledNode(type="F", chunk=1, stage=0, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2),
+           ScheduledNode(type="B", chunk=1, stage=0, minibatch=2),
+           ScheduledNode(type="W", chunk=1, stage=0, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2),
+           ScheduledNode(type="B", chunk=0, stage=0, minibatch=2),
+           ScheduledNode(type="W", chunk=0, stage=0, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
+           # microbatch 3
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3),
+           ScheduledNode(type="F", chunk=0, stage=0, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3),
+           ScheduledNode(type="F", chunk=1, stage=0, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3),
+           ScheduledNode(type="B", chunk=1, stage=0, minibatch=3),
+           ScheduledNode(type="W", chunk=1, stage=0, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3),
+           ScheduledNode(type="B", chunk=0, stage=0, minibatch=3),
+           ScheduledNode(type="W", chunk=0, stage=0, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
        ],
        # stage 1
        [
+           # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=1, minibatch=0),
@@ -94,9 +153,67 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
+           # microbatch 1
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1),
+           ScheduledNode(type="F", chunk=0, stage=1, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1),
+           ScheduledNode(type="F", chunk=1, stage=1, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1),
+           ScheduledNode(type="B", chunk=1, stage=1, minibatch=1),
+           ScheduledNode(type="W", chunk=1, stage=1, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1),
+           ScheduledNode(type="B", chunk=0, stage=1, minibatch=1),
+           ScheduledNode(type="W", chunk=0, stage=1, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
+           # microbatch 2
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2),
+           ScheduledNode(type="F", chunk=0, stage=1, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2),
+           ScheduledNode(type="F", chunk=1, stage=1, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2),
+           ScheduledNode(type="B", chunk=1, stage=1, minibatch=2),
+           ScheduledNode(type="W", chunk=1, stage=1, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2),
+           ScheduledNode(type="B", chunk=0, stage=1, minibatch=2),
+           ScheduledNode(type="W", chunk=0, stage=1, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
+           # microbatch 3
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3),
+           ScheduledNode(type="F", chunk=0, stage=1, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3),
+           ScheduledNode(type="F", chunk=1, stage=1, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3),
+           ScheduledNode(type="B", chunk=1, stage=1, minibatch=3),
+           ScheduledNode(type="W", chunk=1, stage=1, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3),
+           ScheduledNode(type="B", chunk=0, stage=1, minibatch=3),
+           ScheduledNode(type="W", chunk=0, stage=1, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
        ],
        # stage 2
        [
+           # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=2, minibatch=0),
@@ -114,10 +231,68 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="B", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=2, minibatch=0),
-           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0),  # Send nothing
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0),
+           # microbatch 1
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1),
+           ScheduledNode(type="F", chunk=0, stage=2, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1),
+           ScheduledNode(type="F", chunk=1, stage=2, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1),
+           ScheduledNode(type="B", chunk=1, stage=2, minibatch=1),
+           ScheduledNode(type="W", chunk=1, stage=2, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1),
+           ScheduledNode(type="B", chunk=0, stage=2, minibatch=1),
+           ScheduledNode(type="W", chunk=0, stage=2, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1),
+           # microbatch 2
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2),
+           ScheduledNode(type="F", chunk=0, stage=2, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2),
+           ScheduledNode(type="F", chunk=1, stage=2, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2),
+           ScheduledNode(type="B", chunk=1, stage=2, minibatch=2),
+           ScheduledNode(type="W", chunk=1, stage=2, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2),
+           ScheduledNode(type="B", chunk=0, stage=2, minibatch=2),
+           ScheduledNode(type="W", chunk=0, stage=2, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2),
+           # microbatch 3
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3),
+           ScheduledNode(type="F", chunk=0, stage=2, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3),
+           ScheduledNode(type="F", chunk=1, stage=2, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3),
+           ScheduledNode(type="B", chunk=1, stage=2, minibatch=3),
+           ScheduledNode(type="W", chunk=1, stage=2, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3),
+           ScheduledNode(type="B", chunk=0, stage=2, minibatch=3),
+           ScheduledNode(type="W", chunk=0, stage=2, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3),
        ],
        # stage 3
        [
+           # microbatch 0
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=3, minibatch=0),
@@ -136,6 +311,63 @@ def test_run_fwd_bwd_base(
            ScheduledNode(type="B", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0),
+           # microbatch 1
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1),
+           ScheduledNode(type="F", chunk=0, stage=3, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1),
+           ScheduledNode(type="F", chunk=1, stage=3, minibatch=1),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1),
+           ScheduledNode(type="B", chunk=1, stage=3, minibatch=1),
+           ScheduledNode(type="W", chunk=1, stage=3, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1),
+           ScheduledNode(type="B", chunk=0, stage=3, minibatch=1),
+           ScheduledNode(type="W", chunk=0, stage=3, minibatch=1),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1),
+           # microbatch 2
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2),
+           ScheduledNode(type="F", chunk=0, stage=3, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2),
+           ScheduledNode(type="F", chunk=1, stage=3, minibatch=2),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2),
+           ScheduledNode(type="B", chunk=1, stage=3, minibatch=2),
+           ScheduledNode(type="W", chunk=1, stage=3, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2),
+           ScheduledNode(type="B", chunk=0, stage=3, minibatch=2),
+           ScheduledNode(type="W", chunk=0, stage=3, minibatch=2),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2),
+           # microbatch 3
+           # chunk 0 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3),
+           ScheduledNode(type="F", chunk=0, stage=3, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3),
+           # chunk 1 fwd
+           ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3),
+           ScheduledNode(type="F", chunk=1, stage=3, minibatch=3),
+           ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3),
+           # chunk 1 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3),
+           ScheduledNode(type="B", chunk=1, stage=3, minibatch=3),
+           ScheduledNode(type="W", chunk=1, stage=3, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3),
+           # chunk 0 bwd
+           ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3),
+           ScheduledNode(type="B", chunk=0, stage=3, minibatch=3),
+           ScheduledNode(type="W", chunk=0, stage=3, minibatch=3),
+           ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3),
        ],
    ]
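The hand-written schedule above repeats one 14-node pattern per microbatch per stage (chunk 0 fwd, chunk 1 fwd, chunk 1 bwd, chunk 0 bwd). A hedged sketch of generating it programmatically, using a stand-in dataclass rather than the library's ScheduledNode; it reproduces the listing except stage 1's final SEND_BACKWARD nodes, which the listing writes with stage=0:

from dataclasses import dataclass

@dataclass
class ScheduledNode:  # minimal stand-in for the library's ScheduledNode
    type: str
    chunk: int
    stage: int
    minibatch: int

def make_stage_schedule(stage: int, num_microbatch: int) -> list:
    nodes = []
    for mb in range(num_microbatch):
        # per microbatch: chunk 0 fwd, chunk 1 fwd, chunk 1 bwd, chunk 0 bwd
        for chunk, kinds in (
            (0, ("RECV_FORWARD", "F", "SEND_FORWARD")),
            (1, ("RECV_FORWARD", "F", "SEND_FORWARD")),
            (1, ("RECV_BACKWARD", "B", "W", "SEND_BACKWARD")),
            (0, ("RECV_BACKWARD", "B", "W", "SEND_BACKWARD")),
        ):
            nodes += [ScheduledNode(type=k, chunk=chunk, stage=stage, minibatch=mb) for k in kinds]
    return nodes

generated = [make_stage_schedule(s, num_microbatch=4) for s in range(4)]
assert len(generated[0]) == 4 * 14  # 14 nodes per microbatch per stage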
@@ -143,7 +375,7 @@ def test_run_fwd_bwd_base(
        schedule=zbv_schedule[rank],  # hint: send whole schedule or local schedule only ?
        stage_manager=stage_manager,
        num_model_chunks=pp_size,
-       num_microbatch=1,
+       num_microbatch=num_microbatch,
        overlap_p2p=False,
    )
@@ -152,14 +384,15 @@ def test_run_fwd_bwd_base(
        return (x * x).mean()

    # init model and input
+   batch_size = 4
    num_layers = 8
    in_dim = out_dim = 8
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
-   # data_iter = [input0]
-   input_base = input0.clone()
+   data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+   input_base = [t.clone() for t in data_iter]

    model_base = deepcopy(model)

    if rank == 0:
@@ -193,9 +426,7 @@ def test_run_fwd_bwd_base(
    torch.cuda.synchronize()
    scheduler.run_forward_backward(
        model_chunk=local_chunk,
-       input_obj=input0,
-       # data_iter=iter(data_iter),
-       data_iter=None,
+       data_iter=iter(data_iter),
        criterion=criterion,
        optimizer=None,
        return_loss=None,
@@ -206,8 +437,7 @@ def test_run_fwd_bwd_base(
    # Fwd bwd for base
    ##########################
    # fwd & bwd
-   output_base = model_base(input_base)
-   # loss_base = output_base.mean()
+   output_base = model_base(data_iter[0])
    loss_base = criterion(output_base)
    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -245,15 +475,6 @@ def test_run_fwd_bwd_base(
    assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)


-# Test iter input & multiple microbatch
-def test_run_fwd_bwd_iter_input(
-    rank: int,
-    world_size: int,
-    port: int,
-):
-    pass
-
-
@pytest.mark.dist
# @pytest.mark.parametrize("num_microbatch", [4])
# @pytest.mark.parametrize("batch_size", [4])
@@ -261,7 +482,7 @@ def test_run_fwd_bwd_iter_input(
@rerun_if_address_is_in_use()
def test_pp():
    spawn(
-       test_run_fwd_bwd_base,
+       test_run_fwd_bwd_iter_input,
        nprocs=4,
    )