Refactored docstring to google style

This commit is contained in:
Liang Bowen
2022-03-25 13:02:39 +08:00
committed by アマデウス
parent 53b1b6e340
commit ec5086c49c
94 changed files with 3389 additions and 2982 deletions

View File

@@ -12,14 +12,13 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.
:param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent
:param next_rank: The rank of the next member in pipeline parallel group
:type tensor: Tensor
:type need_meta: bool, optional
:type next_rank: int
:return: False
:rtype: bool
Args:
tensor (torch.Tensor): Tensor to be sent.
need_meta (bool, optional): If False, meta information won't be sent.
next_rank (int): The rank of the next member in pipeline parallel group.
Returns:
bool: False
"""
if need_meta:
if next_rank is None:
@@ -36,17 +35,17 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
def recv_tensor_meta(tensor_shape, prev_rank=None):
"""Recieves tensor meta information before recieving a specific tensor.
"""Receives tensor meta information before receiving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function
meta information of the tensor should be received before communications. This function
synchronizes with :func:`send_tensor_meta`.
:param tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor
:type tensor_shape: torch.Size
:type prev_rank: int, optional
:return: The shape of the tensor to be recieved
:rtype: torch.Size
Args:
tensor_shape (torch.Size): The shape of the tensor to be received.
prev_rank (int): The rank of the source of the tensor.
Returns:
torch.Size: The shape of the tensor to be received.
"""
if tensor_shape is None:
if prev_rank is None:
@@ -67,14 +66,12 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks.
:param tensor: Tensor to be splitted before communication
:param new_buffer: Whether uses a new buffer to store sliced tensor
Args:
tensor (torch.Tensor): Tensor to be split before communication.
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
:type tensor: torch.Tensor
:type new_buffer: bool, optional
:return splitted_tensor: The splitted tensor
:rtype splitted_tensor: torch.Tensor
Returns:
torch.Tensor: The split tensor
"""
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -92,11 +89,10 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks.
:param tensor: Tensor to be gathered after communication
:type tensor: torch.Tensor
:return gathered: The gathered tensor
:rtype gathered: torch.Tensor
Args:
tensor (torch.Tensor): Tensor to be gathered after communication.
Returns:
gathered (torch.Tensor): The gathered tensor
"""
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
numel = torch.numel(tensor)