[fix] remove chunk 0 stage 0 bwd b; u don't have to cal micrbatch's dx;

This commit is contained in:
duanjunwen 2024-09-26 10:50:44 +00:00
parent bb0390c90d
commit 64ceea746f

View File

@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module], model_chunk: Union[ModuleList, Module],
model_chunk_id: int, model_chunk_id: int,
optimizer: OptimizerWrapper, optimizer: OptimizerWrapper,
micro_batch: Optional[dict], # micro_batch: Optional[dict],
input_obj: Optional[dict], input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor], output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict], output_obj_grad: Optional[dict],
@ -480,9 +480,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# For chunk 0 stage 0, use micro_batch as input_obj_ # For chunk 0 stage 0, use micro_batch as input_obj_
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj_, _ = tree_flatten(micro_batch) # input_obj_, _ = tree_flatten(micro_batch)
output_obj_, _ = tree_flatten(output_obj) # y # output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy # output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
return None
# For loss backward; output_obj is loss; output_obj_grad should be None # For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@ -512,9 +513,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Format output_obj_grad # Format output_obj_grad
input_obj_grad = {} input_obj_grad = {}
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
for k, v in micro_batch.items(): # for k, v in micro_batch.items():
if isinstance(v, torch.Tensor) and v.grad is not None: # if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad # input_obj_grad[k] = v.grad
pass
else: else:
for k, v in input_obj.items(): for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None: if isinstance(v, torch.Tensor) and v.grad is not None:
@ -643,7 +645,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
tree_map(release_tensor_data, output_obj) tree_map(release_tensor_data, output_obj)
# add input and output object for backward b # add input and output object for backward b
self.input_tensors[model_chunk_id].append((micro_batch, input_obj)) # self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
self.input_tensors[model_chunk_id].append(input_obj)
# for bwd b&w, we only need the graph(grad_fn) of output_obj # for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not release_tensor_data loss, release_tensor_data other output_obj; # Do not release_tensor_data loss, release_tensor_data other output_obj;
@ -701,7 +704,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# get input and output object from buffer; # get input and output object from buffer;
micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0) # micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw # save output_tensor_grad for dw
@ -717,7 +721,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk=model_chunk, model_chunk=model_chunk,
model_chunk_id=model_chunk_id, model_chunk_id=model_chunk_id,
optimizer=optimizer, optimizer=optimizer,
micro_batch=micro_batch,
input_obj=input_obj, input_obj=input_obj,
output_obj=output_obj, output_obj=output_obj,
output_obj_grad=output_tensor_grad, output_obj_grad=output_tensor_grad,
@ -838,6 +841,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# while we still have schedules_node in self.schedules # while we still have schedules_node in self.schedules
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
print(f"schedule {schedule}")
for it in range(len(schedule)): for it in range(len(schedule)):
scheduled_node = schedule[it] scheduled_node = schedule[it]
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: