[feat] update test; rm comments;

This commit is contained in:
duanjunwen
2024-09-02 09:50:47 +00:00
parent a7b767b071
commit 6d18d38d5c
4 changed files with 128 additions and 214 deletions

View File

@@ -353,7 +353,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# bwd chunk1 is left V;
else:
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}")
################
# chunk = 1 && is_last_stage
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
@@ -409,7 +408,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
# print(f"accum_loss {accum_loss}; outputs {len(outputs)}; model_chunk_id {model_chunk_id}")
return loss
else:
return output_obj
@@ -537,11 +535,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Returns:
Nothing.
"""
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from func param
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
input_obj = micro_batch
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
else:
@@ -619,8 +618,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
@@ -643,7 +640,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
)
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
# Step3: send bwd
if model_chunk_id == 0:
@@ -748,9 +744,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# prepare batch
self.load_batch(data_iter)
print(
f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
)
# prepare accum loss & output
accum_loss = None
@@ -762,12 +755,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
# while we still have schedules_node in self.schedules
for it in range(len(self.schedules)):
scheduled_node = self.schedules[it]
print(
f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
)
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)):
scheduled_node = schedule[it]
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]