mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
Refactored docstring to google style
This commit is contained in:
@@ -12,14 +12,13 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
|
||||
meta information of the tensor should be sent before communications. This function
|
||||
synchronizes with :func:`recv_tensor_meta`.
|
||||
|
||||
:param tensor: Tensor to be sent
|
||||
:param need_meta: If False, meta information won't be sent
|
||||
:param next_rank: The rank of the next member in pipeline parallel group
|
||||
:type tensor: Tensor
|
||||
:type need_meta: bool, optional
|
||||
:type next_rank: int
|
||||
:return: False
|
||||
:rtype: bool
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor to be sent.
|
||||
need_meta (bool, optional): If False, meta information won't be sent.
|
||||
next_rank (int): The rank of the next member in pipeline parallel group.
|
||||
|
||||
Returns:
|
||||
bool: False
|
||||
"""
|
||||
if need_meta:
|
||||
if next_rank is None:
|
||||
@@ -36,17 +35,17 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
|
||||
|
||||
|
||||
def recv_tensor_meta(tensor_shape, prev_rank=None):
|
||||
"""Recieves tensor meta information before recieving a specific tensor.
|
||||
"""Receives tensor meta information before receiving a specific tensor.
|
||||
Since the recipient must know the shape of the tensor in p2p communications,
|
||||
meta information of the tensor should be recieved before communications. This function
|
||||
meta information of the tensor should be received before communications. This function
|
||||
synchronizes with :func:`send_tensor_meta`.
|
||||
|
||||
:param tensor_shape: The shape of the tensor to be recieved
|
||||
:param prev_rank: The rank of the source of the tensor
|
||||
:type tensor_shape: torch.Size
|
||||
:type prev_rank: int, optional
|
||||
:return: The shape of the tensor to be recieved
|
||||
:rtype: torch.Size
|
||||
Args:
|
||||
tensor_shape (torch.Size): The shape of the tensor to be received.
|
||||
prev_rank (int): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
torch.Size: The shape of the tensor to be received.
|
||||
"""
|
||||
if tensor_shape is None:
|
||||
if prev_rank is None:
|
||||
@@ -67,14 +66,12 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
|
||||
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
|
||||
"""Break a tensor into equal 1D chunks.
|
||||
|
||||
:param tensor: Tensor to be splitted before communication
|
||||
:param new_buffer: Whether uses a new buffer to store sliced tensor
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor to be split before communication.
|
||||
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
|
||||
|
||||
:type tensor: torch.Tensor
|
||||
:type new_buffer: bool, optional
|
||||
|
||||
:return splitted_tensor: The splitted tensor
|
||||
:rtype splitted_tensor: torch.Tensor
|
||||
Returns:
|
||||
torch.Tensor: The split tensor
|
||||
"""
|
||||
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
@@ -92,11 +89,10 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
|
||||
def gather_split_1d_tensor(tensor):
|
||||
"""Opposite of above function, gather values from model parallel ranks.
|
||||
|
||||
:param tensor: Tensor to be gathered after communication
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
:return gathered: The gathered tensor
|
||||
:rtype gathered: torch.Tensor
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor to be gathered after communication.
|
||||
Returns:
|
||||
gathered (torch.Tensor): The gathered tensor
|
||||
"""
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
numel = torch.numel(tensor)
|
||||
|
Reference in New Issue
Block a user