[fix] fix zbv wait_handle

This commit is contained in:
duanjunwen 2024-11-15 07:56:14 +00:00
parent 5c2ebbfd48
commit cf86c1b1c5

View File

@ -115,10 +115,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.output_tensors_grad_dw = [[], []] self.output_tensors_grad_dw = [[], []]
# buffer for communication # buffer for communication
self.send_forward_buffer = [[], []] self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_forward_buffer = [[], []] self.recv_forward_buffer = [
self.send_backward_buffer = [[], []] [],
self.recv_backward_buffer = [[], []] [],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_backward_buffer = [
[],
[],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
# y buffer for local send fwd # y buffer for local send fwd
self.local_send_forward_buffer = [] self.local_send_forward_buffer = []
@ -257,7 +263,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor) self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles return wait_handles
else: else:
@ -280,7 +286,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor) self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
@ -316,7 +322,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles return wait_handles
else: else:
@ -339,7 +345,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles return wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
@ -651,9 +657,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 0: if model_chunk_id == 0:
# is first stage; get input from microbatch # is first stage; get input from microbatch
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = None input_obj = None # (tensor, wait_handle)
else: else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
else: else:
# is last stage; recv from local # is last stage; recv from local
if self.stage_manager.is_last_stage(ignore_chunk=True): if self.stage_manager.is_last_stage(ignore_chunk=True):
@ -661,7 +670,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# not last stage; recv from next # not last stage; recv from next
else: else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
# Here, let input_obj.requires_grad_() # Here, let input_obj.requires_grad_()
# if input_obj is not None: # if input_obj is not None:
if not isinstance(input_obj, torch.Tensor): if not isinstance(input_obj, torch.Tensor):
@ -751,6 +762,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# chunk0 not last stage; recv output_grad from recv_backward_buffer # chunk0 not last stage; recv output_grad from recv_backward_buffer
else: else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
else: else:
# chunk1, is first stage; recv LOSS from local send bwd buffer # chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
@ -758,6 +772,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# chunk1, not first stage; recv output_grad from recv_backward_buffer # chunk1, not first stage; recv output_grad from recv_backward_buffer
else: else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
# get input and output object from buffer; # get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0) input_obj = self.input_tensors[model_chunk_id].pop(0)
@ -886,9 +903,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
wait_handle = communication_func(scheduled_node.chunk) wait_handle = communication_func(scheduled_node.chunk)
self.wait_handles.append(wait_handle) self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F": elif scheduled_node.type == "F":
for h in self.wait_handles:
for hh in h:
hh.wait()
self.schedule_f( self.schedule_f(
scheduled_node=scheduled_node, scheduled_node=scheduled_node,
model_chunk=model_chunk, model_chunk=model_chunk,
@ -898,9 +912,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs=outputs, outputs=outputs,
) )
elif scheduled_node.type == "B": elif scheduled_node.type == "B":
for h in self.wait_handles:
for hh in h:
hh.wait()
self.schedule_b( self.schedule_b(
scheduled_node=scheduled_node, scheduled_node=scheduled_node,
model_chunk=model_chunk, model_chunk=model_chunk,
@ -914,9 +925,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk, model_chunk_id=scheduled_node.chunk,
optimizer=optimizer, optimizer=optimizer,
) )
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output # return loss & output
if outputs is not None: if outputs is not None:
outputs = merge_batch(outputs) outputs = merge_batch(outputs)