[fix] fix zbv wait_handle

commit cf86c1b1c5
parent 5c2ebbfd48
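Summary: the receive buffers in ZeroBubbleVPipeScheduler previously stored bare tensors, and the run loop drained every outstanding communication handle before each F/B node. This commit makes recv_forward_buffer and recv_backward_buffer store (tensor, wait_handles) tuples so that each consumer waits only on the handles guarding the tensor it pops, and removes the blanket drains. The short code sketches interleaved below are illustrative reconstructions of the pattern, not part of the commit.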
@@ -115,10 +115,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.output_tensors_grad_dw = [[], []]

         # buffer for communication
-        self.send_forward_buffer = [[], []]
-        self.recv_forward_buffer = [[], []]
-        self.send_backward_buffer = [[], []]
-        self.recv_backward_buffer = [[], []]
+        self.send_forward_buffer = [[], []]  # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
+        self.recv_forward_buffer = [
+            [],
+            [],
+        ]  # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
+        self.send_backward_buffer = [[], []]  # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
+        self.recv_backward_buffer = [
+            [],
+            [],
+        ]  # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]

         # y buffer for local send fwd
         self.local_send_forward_buffer = []
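The new tuple layout lets the scheduler enqueue a tensor as soon as its non-blocking receive is posted and defer the blocking wait() to the consumer. A minimal runnable sketch of that pattern with plain torch.distributed point-to-point ops (illustrative code, not from the commit; the scheduler posts its receives through its own communication helpers):

# Illustrative sketch of the deferred-wait pattern; not the commit's code.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    if rank == 0:
        dist.isend(torch.arange(4, dtype=torch.float32), dst=1).wait()
    else:
        recv_buffer = []  # entries shaped like the new buffers: (tensor, [wait_handle])
        t = torch.empty(4)
        handle = dist.irecv(t, src=0)  # non-blocking receive
        recv_buffer.append((t, [handle]))  # enqueue without waiting
        # ...other scheduled work can overlap with the transfer here...
        tensor, handles = recv_buffer.pop(0)
        for h in handles:  # block only when the tensor is actually consumed
            h.wait()
        assert torch.equal(tensor, torch.arange(4, dtype=torch.float32))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)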
@@ -257,7 +263,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 )
                 if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
                     self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
-                self.recv_forward_buffer[model_chunk_id].append(input_tensor)
+                self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
                 return wait_handles

         else:
@@ -280,7 +286,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 )
                 if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
                     self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
-                self.recv_forward_buffer[model_chunk_id].append(input_tensor)
+                self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
                 return wait_handles

     def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
@@ -316,7 +322,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 )
                 if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
                     self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
-                self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
+                self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
                 return wait_handles

         else:
@@ -339,7 +345,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 )
                 if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
                     self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
-                self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
+                self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
                 return wait_handles

     def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
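All four receive paths now append the same entry shape. Made explicit as a type alias (purely illustrative; the commit documents the shape only in comments):

from typing import Any, List, Tuple
import torch

# One buffer entry: the received tensor plus the async work handles
# (torch.distributed work objects exposing .wait()) still guarding it.
BufferEntry = Tuple[torch.Tensor, List[Any]]
RecvBuffer = List[List[BufferEntry]]  # outer list indexed by model chunk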
@@ -651,9 +657,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if model_chunk_id == 0:
             # is first stage; get input from microbatch
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                input_obj = None
+                input_obj = None  # (tensor, wait_handle)
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+                for h in input_obj[1]:
+                    h.wait()
+                input_obj = input_obj[0]
         else:
             # is last stage; recv from local
             if self.stage_manager.is_last_stage(ignore_chunk=True):
@@ -661,7 +670,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # not last stage; recv from next
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+                for h in input_obj[1]:
+                    h.wait()
+                input_obj = input_obj[0]
         # Here, let input_obj.requires_grad_()
         # if input_obj is not None:
         if not isinstance(input_obj, torch.Tensor):
@@ -751,6 +762,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # chunk0 not last stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+                for h in output_tensor_grad[1]:
+                    h.wait()
+                output_tensor_grad = output_tensor_grad[0]
         else:
             # chunk1, is first stage; recv LOSS from local send bwd buffer
             if self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -758,6 +772,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # chunk1, not first stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+                for h in output_tensor_grad[1]:
+                    h.wait()
+                output_tensor_grad = output_tensor_grad[0]

         # get input and output object from buffer;
         input_obj = self.input_tensors[model_chunk_id].pop(0)
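The pop-then-wait-then-unpack sequence repeats at all four consumption sites above; a helper equivalent to that idiom could read (hypothetical refactor, not in the commit):

from typing import Any, List, Tuple


def pop_and_wait(buffer: List[Tuple[Any, List[Any]]]) -> Any:
    """Pop the oldest (tensor, wait_handles) entry, block on its
    outstanding handles, and return the bare tensor."""
    tensor, handles = buffer.pop(0)
    for h in handles:
        h.wait()
    return tensor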
@@ -886,9 +903,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     wait_handle = communication_func(scheduled_node.chunk)
                     self.wait_handles.append(wait_handle)
                 elif scheduled_node.type == "F":
-                    for h in self.wait_handles:
-                        for hh in h:
-                            hh.wait()
                     self.schedule_f(
                         scheduled_node=scheduled_node,
                         model_chunk=model_chunk,
@@ -898,9 +912,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                         outputs=outputs,
                     )
                 elif scheduled_node.type == "B":
-                    for h in self.wait_handles:
-                        for hh in h:
-                            hh.wait()
                     self.schedule_b(
                         scheduled_node=scheduled_node,
                         model_chunk=model_chunk,
@@ -914,9 +925,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                         model_chunk_id=scheduled_node.chunk,
                         optimizer=optimizer,
                     )
-            for h in self.wait_handles:
-                for hh in h:
-                    hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
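The three removed loops blocked on every outstanding handle before each F/B node and once after the node loop. Since each received tensor now carries its own handles and is waited on at pop time, that blanket drain is no longer needed for received data. A toy contrast with dummy handles (illustrative only, not the commit's code):

class DummyHandle:
    """Stands in for an async torch.distributed work handle."""

    def __init__(self, name: str):
        self.name = name

    def wait(self):
        print(f"wait({self.name})")


# old behaviour: drain everything before running any F/B node,
# blocking even on handles for tensors this node never touches
wait_handles = [[DummyHandle("fwd-chunk0"), DummyHandle("bwd-chunk1")]]
for h in wait_handles:
    for hh in h:
        hh.wait()

# new behaviour: wait only on the handles attached to the popped tensor
entry = ("tensor-placeholder", [DummyHandle("fwd-chunk0")])
tensor, handles = entry
for h in handles:
    h.wait()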