mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[feat] update test; rm comments;
This commit is contained in:
@@ -353,7 +353,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
|
||||
# bwd chunk1 is left V;
|
||||
else:
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}")
|
||||
################
|
||||
# chunk = 1 && is_last_stage
|
||||
# do nothing; input_tensor_grad was already sent to local_send_bwd_buffer in schedule b;
|
||||
@@ -409,7 +408,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
# print(f"accum_loss {accum_loss}; outputs {len(outputs)}; model_chunk_id {model_chunk_id}")
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
@@ -537,11 +535,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
# Step1: recv fwd
|
||||
if model_chunk_id == 0:
|
||||
# is first stage; get input from func param
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
input_obj = micro_batch
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
else:
|
||||
@@ -619,8 +618,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
|
||||
|
||||
# get input and output object from buffer;
|
||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||
@@ -643,7 +640,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
output_obj=output_obj,
|
||||
output_obj_grad=output_tensor_grad,
|
||||
)
|
||||
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
|
||||
|
||||
# Step3: send bwd
|
||||
if model_chunk_id == 0:
|
||||
@@ -748,9 +744,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
"""
|
||||
# prepare batch
|
||||
self.load_batch(data_iter)
|
||||
print(
|
||||
f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
|
||||
)
|
||||
|
||||
# prepare accum loss & output
|
||||
accum_loss = None
|
||||
@@ -762,12 +755,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
|
||||
|
||||
# while we still have schedules_node in self.schedules
|
||||
for it in range(len(self.schedules)):
|
||||
scheduled_node = self.schedules[it]
|
||||
|
||||
print(
|
||||
f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
|
||||
)
|
||||
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
||||
for it in range(len(schedule)):
|
||||
scheduled_node = schedule[it]
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
|
Reference in New Issue
Block a user