mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 21:17:08 +00:00
[pipeline]refactor ppschedule to support tensor list (#1050)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* refactor ppschedule to support tensor list
* polish
This commit is contained in:
@@ -9,14 +9,21 @@ from typing import Union, List, Tuple
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
||||
def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
|
||||
"""Sends tensor meta information before sending a specific tensor.
|
||||
Since the recipient must know the shape of the tensor in p2p communications,
|
||||
meta information of the tensor should be sent before communications. This function
|
||||
synchronizes with :func:`recv_tensor_meta`.
|
||||
def send_meta_helper(obj, next_rank, tensor_kwargs):
|
||||
send_shape = torch.tensor(obj.size(), **tensor_kwargs)
|
||||
send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
|
||||
dist.send(send_ndims, next_rank)
|
||||
dist.send(send_shape, next_rank)
|
||||
|
||||
|
||||
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
|
||||
"""Sends obj meta information before sending a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be sent before communications. This function
|
||||
synchronizes with :func:`recv_obj_meta`.
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): Tensor to be sent.
|
||||
obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
|
||||
need_meta (bool, optional): If False, meta information won't be sent.
|
||||
next_rank (int): The rank of the next member in pipeline parallel group.
|
||||
|
||||
@@ -28,42 +35,57 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
|
||||
|
||||
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
|
||||
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
|
||||
dist.send(send_ndims, next_rank)
|
||||
dist.send(send_shape, next_rank)
|
||||
if isinstance(obj, torch.Tensor):
|
||||
send_obj_nums = torch.tensor(1, **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
send_meta_helper(obj, next_rank, tensor_kwargs)
|
||||
else:
|
||||
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
for tensor_to_send in obj:
|
||||
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
|
||||
"""Receives tensor meta information before receiving a specific tensor.
|
||||
Since the recipient must know the shape of the tensor in p2p communications,
|
||||
meta information of the tensor should be received before communications. This function
|
||||
synchronizes with :func:`send_tensor_meta`.
|
||||
def recv_meta_helper(prev_rank, tensor_kwargs):
|
||||
recv_ndims = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_ndims, prev_rank)
|
||||
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
|
||||
dist.recv(recv_shape, prev_rank)
|
||||
return recv_shape
|
||||
|
||||
|
||||
def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
|
||||
"""Receives obj meta information before receiving a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be received before communications. This function
|
||||
synchronizes with :func:`send_obj_meta`.
|
||||
|
||||
Args:
|
||||
tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
|
||||
prev_rank (int): The rank of the source of the tensor.
|
||||
obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
|
||||
prev_rank (int): The rank of the source of the obj.
|
||||
|
||||
Returns:
|
||||
:class:`torch.Size`: The shape of the tensor to be received.
|
||||
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
|
||||
"""
|
||||
if tensor_shape is None:
|
||||
if obj_shape is None:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
|
||||
recv_obj_nums = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_obj_nums, prev_rank)
|
||||
if recv_obj_nums.item() == 1:
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape = torch.Size(recv_shape)
|
||||
else:
|
||||
obj_shape = []
|
||||
for i in range(recv_obj_nums.item()):
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape.append(torch.Size(recv_shape))
|
||||
|
||||
recv_ndims = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_ndims, prev_rank)
|
||||
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
|
||||
dist.recv(recv_shape, prev_rank)
|
||||
|
||||
tensor_shape = torch.Size(recv_shape)
|
||||
|
||||
return tensor_shape
|
||||
return obj_shape
|
||||
|
||||
|
||||
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
|
||||
|
Reference in New Issue
Block a user