mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[NFC]fix typo colossalai/auto_parallel nn utils etc. (#3779)
* fix typo colossalai/autochunk auto_parallel amp * fix typo colossalai/auto_parallel nn utils etc.
This commit is contained in:
@@ -6,12 +6,12 @@ import torch
|
||||
|
||||
class PreviousStatus(Enum):
|
||||
"""
|
||||
This class shows the status of previous comparision.
|
||||
This class shows the status of previous comparison.
|
||||
"""
|
||||
RESET = 0
|
||||
# ORIGIN means the dimension size of original tensor is larger in the previous comparision.
|
||||
# ORIGIN means the dimension size of original tensor is larger in the previous comparison.
|
||||
ORIGIN = 1
|
||||
# TGT means the dimension size of target tensor is larger in the previous comparision.
|
||||
# TGT means the dimension size of target tensor is larger in the previous comparison.
|
||||
TGT = 2
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
|
||||
tgt_index += 1
|
||||
|
||||
if previous_label == PreviousStatus.TGT:
|
||||
# if the target dimension size is larger in the previous comparision, which means
|
||||
# if the target dimension size is larger in the previous comparison, which means
|
||||
# the origin dimension size has already accumulated larger than target dimension size, so
|
||||
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
|
||||
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
|
||||
@@ -111,7 +111,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
|
||||
origin_index += 1
|
||||
|
||||
if previous_label == PreviousStatus.ORIGIN:
|
||||
# if the origin element is larger in the previous comparision, which means
|
||||
# if the origin element is larger in the previous comparison, which means
|
||||
# the target element has already accumulated larger than origin element, so
|
||||
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
|
||||
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
|
||||
@@ -139,7 +139,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
|
||||
Rule:
|
||||
For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
|
||||
the function will return false.
|
||||
To illustrate this issue, there are two cases to analyse:
|
||||
To illustrate this issue, there are two cases to analyze:
|
||||
1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
|
||||
operation without distributed tensor.
|
||||
2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
|
||||
|
Reference in New Issue
Block a user