[pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate

* feat: finish tree_flatten version p2p

* fix: update p2p api calls

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
Author: Elsa Granger
Date: 2024-01-08 15:37:27 +08:00
Committed by: GitHub
Parent: 7bc6969ce6
Commit: d565df3821
4 changed files with 104 additions and 136 deletions
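The commit title describes making `_communicate` general over arbitrary object structures via `torch.utils._pytree`, as the "tree_flatten version p2p" bullet says. A minimal sketch of that idea, assuming nothing about the real `_communicate` signature (`pack`/`unpack` below are illustrative names, not ColossalAI API): flatten any nested object into tensor leaves plus a treespec, transmit the leaves as tensors and the spec as metadata, then rebuild on the receiving side.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def pack(obj):
    # Flatten an arbitrary nested structure (tensor, tuple, list, dict, ...)
    # into a flat list of leaves plus a treespec describing the nesting.
    # The leaves can be sent as plain tensors; the treespec (plus each
    # leaf's shape/dtype) is the "metadata" that describes the message.
    return tree_flatten(obj)

def unpack(leaves, spec):
    # Rebuild the original nested object on the receiving side.
    return tree_unflatten(leaves, spec)

# One code path now handles a lone tensor, a tuple, or a dict of tensors:
obj = {"hidden_states": torch.randn(2, 4), "mask": torch.ones(2, dtype=torch.bool)}
leaves, spec = pack(obj)
rebuilt = unpack(leaves, spec)
assert torch.equal(rebuilt["hidden_states"], obj["hidden_states"])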


@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map
 
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -121,7 +121,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
 
         return input_tensor
@@ -138,7 +138,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
 
         return output_tensor_grad
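Both receive hunks above follow the same caching pattern: the first receive pays for a full metadata handshake, then create_send_metadata derives the peer's metadata from the object that actually arrived, so every later receive can skip the handshake. A condensed sketch of that pattern, where comm_recv and make_metadata are hypothetical stand-ins for the real calls, not their actual signatures:

from typing import Any, Callable, Optional

class CachedReceiver:
    # Sketch only: comm_recv stands in for PipelineP2PCommunication.recv_forward
    # and make_metadata for create_send_metadata; real signatures may differ.
    def __init__(self, comm_recv: Callable, make_metadata: Callable, enable_metadata_cache: bool = True):
        self.comm_recv = comm_recv
        self.make_metadata = make_metadata
        self.enable_metadata_cache = enable_metadata_cache
        self.metadata_recv: Optional[Any] = None  # None until the first receive

    def recv(self, peer: int) -> Any:
        # metadata_recv=None forces a metadata exchange on the first call;
        # a cached value lets later calls skip it entirely.
        obj = self.comm_recv(peer, metadata_recv=self.metadata_recv)
        if self.enable_metadata_cache and self.metadata_recv is None:
            # The 1F1B schedule receives the same structure every step, so
            # describing the first received object once is enough.
            self.metadata_recv = self.make_metadata(obj)
        return obj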
@@ -188,7 +188,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
 
         return output_tensor_grad
@@ -214,7 +214,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
 
         return input_tensor
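The fused send/recv hunks add a send-side cache on top: after the first exchange, send_tensor_metadata / send_grad_metadata flip to False so the local stage stops retransmitting its own metadata, while the received object seeds the receive cache exactly as above. Sketched with the same hypothetical stand-ins:

from typing import Any, Callable, Optional

class CachedExchanger:
    # Sketch only: comm_send_recv stands in for a fused call such as
    # PipelineP2PCommunication.send_forward_recv_backward; signatures are assumed.
    def __init__(self, comm_send_recv: Callable, make_metadata: Callable, enable_metadata_cache: bool = True):
        self.comm_send_recv = comm_send_recv
        self.make_metadata = make_metadata
        self.enable_metadata_cache = enable_metadata_cache
        self.send_metadata = True                 # first send always carries metadata
        self.metadata_recv: Optional[Any] = None  # receive-side cache, as before

    def send_recv(self, obj: Any, peer: int) -> Any:
        recv_obj = self.comm_send_recv(
            obj, peer, send_metadata=self.send_metadata, metadata_recv=self.metadata_recv
        )
        # The peer has cached our metadata now, so stop sending it
        # (unless caching is disabled), and cache theirs symmetrically.
        self.send_metadata = not self.enable_metadata_cache
        if self.enable_metadata_cache and self.metadata_recv is None:
            self.metadata_recv = self.make_metadata(recv_obj)
        return recv_obj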