mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-04 19:16:42 +00:00
[fix] fix communication_map;
This commit is contained in:
parent
8eb6eac225
commit
a7b767b071
@ -60,6 +60,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# P2P communication
|
# P2P communication
|
||||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||||
|
|
||||||
|
# init communication map
|
||||||
|
self.communication_map = {
|
||||||
|
"SEND_FORWARD": self.send_forward,
|
||||||
|
"RECV_FORWARD": self.recv_forward,
|
||||||
|
"SEND_BACKWARD": self.send_backward,
|
||||||
|
"RECV_BACKWARD": self.recv_backward,
|
||||||
|
}
|
||||||
|
|
||||||
# init buffer
|
# init buffer
|
||||||
self._free_buffers()
|
self._free_buffers()
|
||||||
|
|
||||||
@ -162,14 +170,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||||
return model_chunk_id
|
return model_chunk_id
|
||||||
|
|
||||||
def communication_func_map(self, node_type: str):
|
|
||||||
return {
|
|
||||||
"SEND_FORWARD": self.send_forward,
|
|
||||||
"RECV_FORWARD": self.recv_forward,
|
|
||||||
"SEND_BACKWARD": self.send_backward,
|
|
||||||
"RECV_BACKWARD": self.recv_backward,
|
|
||||||
}[node_type]
|
|
||||||
|
|
||||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
|
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
|
||||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||||
For ZBV.
|
For ZBV.
|
||||||
@ -718,7 +718,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
|
|
||||||
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
|
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
|
||||||
# communication
|
# communication
|
||||||
communication_func = self.communication_func_map(scheduled_node.type)
|
communication_func = self.communication_map[scheduled_node.type]
|
||||||
communication_func(scheduled_node.chunk)
|
communication_func(scheduled_node.chunk)
|
||||||
if scheduled_node.type == "F":
|
if scheduled_node.type == "F":
|
||||||
self.schedule_f(
|
self.schedule_f(
|
||||||
@ -770,7 +770,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
)
|
)
|
||||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||||
# communication
|
# communication
|
||||||
communication_func = self.communication_func_map(scheduled_node.type)
|
communication_func = self.communication_map[scheduled_node.type]
|
||||||
communication_func(scheduled_node.chunk)
|
communication_func(scheduled_node.chunk)
|
||||||
|
|
||||||
if scheduled_node.type == "F":
|
if scheduled_node.type == "F":
|
||||||
|
Loading…
Reference in New Issue
Block a user