[feat] Add run_fwd_bwd_with_microbatch (replaces the explicit input_obj argument) and its test; add assert-close tests for p and p.grad; all tests pass.

This commit is contained in:
duanjunwen
2024-08-27 09:31:38 +00:00
parent 9e0bd1af00
commit 8b37323f16
2 changed files with 247 additions and 28 deletions

View File

@@ -495,7 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
scheduled_node,
model_chunk: torch.nn.ModuleList,
model_chunk_id: int,
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
@@ -506,7 +505,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
input_obj (Optional[dict]): x;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
@@ -518,7 +516,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 0:
# is first stage; get input from func param
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = input_obj
input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
else:
@@ -671,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
input_obj: Optional[dict],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
@@ -683,7 +680,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# # prepare batch
self.load_batch(data_iter)
# print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
print(
f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
)
it = 0
# while we still have schedules_node in self.schedules
@@ -707,7 +706,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
input_obj=input_obj,
criterion=criterion,
accum_loss=return_loss,
outputs=return_outputs,