[fix] fix wait handle in run_fwd_bwd

This commit is contained in:
duanjunwen 2024-11-18 02:50:14 +00:00
parent f48a85e91d
commit 9a21f87ed6

View File

@ -899,7 +899,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# communication # communication
communication_func = self.communication_map[scheduled_node.type] communication_func = self.communication_map[scheduled_node.type]
wait_handle = communication_func(scheduled_node.chunk) wait_handle = communication_func(scheduled_node.chunk)
self.wait_handles.append(wait_handle) # We wait recv handle in fwd step and bwd step. Here only need to wait for send handle
if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F": elif scheduled_node.type == "F":
self.schedule_f( self.schedule_f(
scheduled_node=scheduled_node, scheduled_node=scheduled_node,