Mirror of https://github.com/hpcaitech/ColossalAI.git
[pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate
* feat: finish tree_flatten version p2p
* fix: update p2p api calls

Co-authored-by: Wenhao Chen <cwher@outlook.com>
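The commit renames create_fast_send_metadata to create_send_metadata and, per the commit message, finishes a tree_flatten-based version of the p2p layer, so arbitrarily nested containers of tensors can be exchanged between pipeline stages. As a rough illustration only (the P2PMetadata container and its fields below are assumptions, not the actual colossalai.pipeline.p2p definitions), a tree_flatten-based metadata builder could look like this:

# Illustrative sketch only: the real create_send_metadata in colossalai.pipeline.p2p
# may use a different metadata class with different fields.
from dataclasses import dataclass
from typing import Any, List

import torch
from torch.utils._pytree import tree_flatten


@dataclass
class P2PMetadata:
    # Hypothetical container for cached send/receive metadata.
    tree_spec: Any                 # pytree structure of the transmitted object
    shapes: List[torch.Size]       # shape of every tensor leaf, in flatten order
    dtypes: List[torch.dtype]      # dtype of every tensor leaf, in flatten order


def create_send_metadata(obj: Any) -> P2PMetadata:
    # Flatten the (possibly nested) object and record per-tensor metadata,
    # so later receives can allocate buffers without re-negotiating shapes/dtypes.
    leaves, spec = tree_flatten(obj)
    tensors = [leaf for leaf in leaves if isinstance(leaf, torch.Tensor)]
    return P2PMetadata(
        tree_spec=spec,
        shapes=[t.shape for t in tensors],
        dtypes=[t.dtype for t in tensors],
    )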
@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map
 
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -121,7 +121,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
 
             return input_tensor
@@ -138,7 +138,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
 
             return output_tensor_grad
@@ -188,7 +188,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
 
             return output_tensor_grad
@@ -214,7 +214,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
 
             return input_tensor
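All four code hunks are the same call-site update: wherever a receive path caches metadata, the cached entry is now built with create_send_metadata. To sketch the "more general _communicate" idea under the same assumed P2PMetadata container as above (this is not ColossalAI's actual _communicate; it assumes an already-initialized torch.distributed process group and tensor-only leaves):

# Hypothetical pytree-based point-to-point exchange; for illustration only.
from typing import Any

import torch
import torch.distributed as dist
from torch.utils._pytree import tree_flatten, tree_unflatten


def send_object(obj: Any, dst: int) -> None:
    # Flatten the object and send every tensor leaf to rank `dst` in order.
    leaves, _ = tree_flatten(obj)
    for leaf in leaves:
        dist.send(leaf.contiguous(), dst=dst)


def recv_object(metadata: "P2PMetadata", src: int) -> Any:
    # Allocate buffers from the cached metadata, receive the leaves,
    # then rebuild the original nested structure.
    buffers = [
        torch.empty(shape, dtype=dtype)
        for shape, dtype in zip(metadata.shapes, metadata.dtypes)
    ]
    for buf in buffers:
        dist.recv(buf, src=src)
    return tree_unflatten(buffers, metadata.tree_spec)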