mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass;
This commit is contained in:
@@ -495,7 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
scheduled_node,
|
||||
model_chunk: torch.nn.ModuleList,
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
@@ -506,7 +505,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
input_obj (Optional[dict]): x;
|
||||
criterion (Callable): loss function;
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
|
||||
@@ -518,7 +516,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
if model_chunk_id == 0:
|
||||
# is first stage; get input from func param
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = input_obj
|
||||
input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
else:
|
||||
@@ -671,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
def run_forward_backward(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
input_obj: Optional[dict],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
@@ -683,7 +680,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
"""
|
||||
# # prepare batch
|
||||
self.load_batch(data_iter)
|
||||
# print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
|
||||
print(
|
||||
f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
|
||||
)
|
||||
|
||||
it = 0
|
||||
# while we still have schedules_node in self.schedules
|
||||
@@ -707,7 +706,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
input_obj=input_obj,
|
||||
criterion=criterion,
|
||||
accum_loss=return_loss,
|
||||
outputs=return_outputs,
|
||||
|
Reference in New Issue
Block a user