[pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate

* feat: finish tree_flatten version of p2p

* fix: update p2p api calls

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
Author: Elsa Granger
Date: 2024-01-08 15:37:27 +08:00
Committed by: GitHub
Parent: 7bc6969ce6
Commit: d565df3821
4 changed files with 104 additions and 136 deletions
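
For context: the "tree_flatten version" in the commit message refers to flattening an arbitrary nested object (dicts/lists/tuples of tensors) into a list of tensor leaves plus a tree spec, so one generic communication path can ship any structure and the receiver can rebuild it. A minimal sketch of that idea, assuming torch.utils._pytree; sketch_create_send_metadata and its return layout are illustrative, not the actual create_send_metadata in colossalai/pipeline/p2p.py:

import torch
from torch.utils._pytree import tree_flatten

def sketch_create_send_metadata(obj):
    # Hypothetical illustration: flatten the nested object into tensor
    # leaves plus a TreeSpec describing the original structure.
    leaves, tree_spec = tree_flatten(obj)
    # Per-leaf shape/dtype lets the receiver pre-allocate buffers and,
    # once cached, skip the metadata handshake on later iterations.
    tensor_meta = [(t.shape, t.dtype) for t in leaves if torch.is_tensor(t)]
    return tree_spec, tensor_meta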

@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
 from colossalai.interface import OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -130,7 +130,7 @@ class InterleavedSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
             return input_tensor
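
This hunk and the three below follow the same metadata-caching protocol: only the first exchange carries metadata, after which each side relies on a local cache. A hedged sketch of both halves, assuming a send_metadata keyword on the send path (only the recv signature is visible in this diff); the attribute names come from the hunks, the free functions are hypothetical:

# Sender: transmit metadata only until the cache is warm.
def send(comm, obj, next_rank, state):
    comm.send_forward(obj, next_rank, send_metadata=state.send_tensor_metadata)
    state.send_tensor_metadata = not state.enable_metadata_cache  # False once cached

# Receiver: rebuild the identical metadata from the first received
# object and cache it, so later metadata-free messages still decode.
def recv(comm, prev_rank, state):
    obj = comm.recv_forward(prev_rank, metadata_recv=state.tensor_metadata_recv)
    if state.enable_metadata_cache and state.tensor_metadata_recv is None:
        state.tensor_metadata_recv = create_send_metadata(obj)
    return obj

Calling create_send_metadata on the receiver works because metadata derived from the received object matches what the sender would compute, which is presumably why a single helper serves both sides after the rename.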
@@ -149,7 +149,7 @@ class InterleavedSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
             return output_tensor_grad
@@ -206,7 +206,7 @@ class InterleavedSchedule(PipelineSchedule):
             )
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
             return output_tensor_grad
         # send only or recv only
@@ -238,7 +238,7 @@ class InterleavedSchedule(PipelineSchedule):
             )
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
             return input_tensor
         # send only or recv only
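
The two fused hunks above are where "a more general _communicate" pays off: one primitive covers send-only, recv-only, and combined send+recv. A minimal sketch of that shape using torch.distributed batched p2p ops; sketch_communicate and its parameters are illustrative, not the actual _communicate in colossalai/pipeline/p2p.py:

import torch
import torch.distributed as dist

def sketch_communicate(send_tensor=None, send_dst=None, recv_buffer=None, recv_src=None):
    # One code path for send-only, recv-only, and fused send+recv.
    ops = []
    if send_tensor is not None:
        ops.append(dist.P2POp(dist.isend, send_tensor, send_dst))
    if recv_buffer is not None:
        ops.append(dist.P2POp(dist.irecv, recv_buffer, recv_src))
    if ops:
        # Batching both directions in one call sidesteps the ordering
        # deadlocks of issuing blocking send/recv pairs rank by rank.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return recv_buffer

In this sketch the recv buffer would be pre-allocated from the cached metadata discussed above, which is what lets the fused path run without a per-step shape handshake.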