[fix] fix communication_map: replace the communication_func_map method with a communication_map dict attribute

This commit is contained in:
duanjunwen 2024-08-30 05:56:02 +00:00
parent 8eb6eac225
commit a7b767b071

View File

@@ -60,6 +60,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# P2P communication # P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
# init communication map
self.communication_map = {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}
# init buffer # init buffer
self._free_buffers() self._free_buffers()
@@ -162,14 +170,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id = self.num_model_chunks - model_chunk_id - 1 model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id return model_chunk_id
def communication_func_map(self, node_type: str):
return {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}[node_type]
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV. For ZBV.
@@ -718,7 +718,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
# communication # communication
- communication_func = self.communication_func_map(scheduled_node.type)
+ communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk) communication_func(scheduled_node.chunk)
if scheduled_node.type == "F": if scheduled_node.type == "F":
self.schedule_f( self.schedule_f(
@@ -770,7 +770,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication # communication
- communication_func = self.communication_func_map(scheduled_node.type)
+ communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk) communication_func(scheduled_node.chunk)
if scheduled_node.type == "F": if scheduled_node.type == "F":