mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[pipeline] A more general _communicate in p2p (#5062)
* A more general _communicate * feat: finish tree_flatten version p2p * fix: update p2p api calls --------- Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
@@ -130,7 +130,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
@@ -149,7 +149,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
@@ -206,7 +206,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad
|
||||
|
||||
# send only or recv only
|
||||
@@ -238,7 +238,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
# send only or recv only
|
||||
|
Reference in New Issue
Block a user