[NFC]fix typo colossalai/auto_parallel nn utils etc. (#3779)

* fix typo colossalai/autochunk auto_parallel amp

* fix typo colossalai/auto_parallel nn utils etc.
digger yu
2023-05-23 15:28:20 +08:00
committed by GitHub
parent e871e342b3
commit 9265f2d4d7
16 changed files with 46 additions and 46 deletions

@@ -6,12 +6,12 @@ import torch
class PreviousStatus(Enum):
"""
-This class shows the status of previous comparision.
+This class shows the status of previous comparison.
"""
RESET = 0
-# ORIGIN means the dimension size of original tensor is larger in the previous comparision.
+# ORIGIN means the dimension size of original tensor is larger in the previous comparison.
ORIGIN = 1
-# TGT means the dimension size of target tensor is larger in the previous comparision.
+# TGT means the dimension size of target tensor is larger in the previous comparison.
TGT = 2
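For intuition, here is a hedged trace of the state machine these values describe; the shapes and variable names below are illustrative, not code from the library:

```python
from enum import Enum

class PreviousStatus(Enum):
    RESET = 0
    ORIGIN = 1  # origin dim size was larger in the previous comparison
    TGT = 2     # target dim size was larger in the previous comparison

# Hypothetical comparison of origin_shape (4,) against tgt_shape (2, 2):
status = PreviousStatus.RESET
origin_acc, tgt_acc = 4, 2      # first comparison: 4 > 2
status = PreviousStatus.ORIGIN  # record that the origin side ran ahead
tgt_acc *= 2                    # target side keeps accumulating: 2 * 2 = 4
assert origin_acc == tgt_acc    # sizes match again, so the dim pair can be recorded
status = PreviousStatus.RESET
```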
@@ -91,7 +91,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
tgt_index += 1
if previous_label == PreviousStatus.TGT:
-# if the target dimension size is larger in the previous comparision, which means
+# if the target dimension size is larger in the previous comparison, which means
# the origin dimension size has already accumulated larger than target dimension size, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
@@ -111,7 +111,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
origin_index += 1
if previous_label == PreviousStatus.ORIGIN:
-# if the origin element is larger in the previous comparision, which means
+# if the origin element is larger in the previous comparison, which means
# the target element has already accumulated larger than origin element, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
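Taken together, the two branches above describe a two-pointer accumulation. Below is a simplified, self-contained sketch of that idea; `detect_reshape_mapping_sketch` is a stand-in name, and the real function tracks more state (including `PreviousStatus`):

```python
import math
from typing import Dict, List, Tuple

import torch


def detect_reshape_mapping_sketch(origin_shape: torch.Size,
                                  tgt_shape: torch.Size) -> Dict[Tuple[int, ...], Tuple[int, ...]]:
    """Map groups of origin dims to groups of target dims by accumulating
    sizes on whichever side is currently smaller until the products match."""
    assert math.prod(origin_shape) == math.prod(tgt_shape)
    mapping: Dict[Tuple[int, ...], Tuple[int, ...]] = {}
    origin_index = tgt_index = 0
    origin_acc = tgt_acc = 1
    origin_dims: List[int] = []
    tgt_dims: List[int] = []
    while origin_index < len(origin_shape) or tgt_index < len(tgt_shape):
        if origin_index < len(origin_shape) and origin_acc <= tgt_acc:
            origin_acc *= origin_shape[origin_index]
            origin_dims.append(origin_index)
            origin_index += 1
        else:
            tgt_acc *= tgt_shape[tgt_index]
            tgt_dims.append(tgt_index)
            tgt_index += 1
        if origin_acc == tgt_acc and origin_dims and tgt_dims:
            # offload the accumulated origin dims and tgt dims, as in the diff above
            mapping[tuple(origin_dims)] = tuple(tgt_dims)
            origin_dims, tgt_dims = [], []
    return mapping


print(detect_reshape_mapping_sketch(torch.Size([2, 3, 4]), torch.Size([6, 4])))
# expected (under this sketch): {(0, 1): (0,), (2,): (1,)}
```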
@@ -139,7 +139,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
Rule:
For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
the function will return false.
-To illustrate this issue, there are two cases to analyse:
+To illustrate this issue, there are two cases to analyze:
1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
operation without distributed tensor.
2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
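To make the rule above concrete, a hedged stand-in following the signature in the hunk header; this is a simplification for illustration, not the library's exact logic:

```python
from typing import Dict, List, Tuple


def check_keep_sharding_status_sketch(
        input_dim_partition_dict: Dict[int, List[int]],
        reshape_mapping_dict: Dict[Tuple[int, ...], Tuple[int, ...]]) -> bool:
    """Sharding survives the reshape only if every sharded input dim is the
    minimum element of the input tuple it belongs to."""
    for origin_dims in reshape_mapping_dict:
        for sharded_dim in input_dim_partition_dict:
            if sharded_dim in origin_dims and sharded_dim != min(origin_dims):
                return False
    return True


# Hypothetical example: origin dims (0, 1) collapse into one target dim, so
# sharding dim 0 (the minimum) can be kept, while sharding dim 1 cannot.
print(check_keep_sharding_status_sketch({0: [0]}, {(0, 1): (0,)}))  # True
print(check_keep_sharding_status_sketch({1: [0]}, {(0, 1): (0,)}))  # False
```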