[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)

This commit is contained in:
digger yu
2023-05-19 13:50:00 +08:00
committed by GitHub
parent 21e29e2212
commit 32f81f14d4
6 changed files with 12 additions and 12 deletions

View File

@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
_is_batch_dims_same = False
# retireve dimensions
# retrieve dimensions
input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2]

View File

@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule):
def _act_annotation_pass(gm: torch.fx.GraphModule):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""

View File

@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
return size
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor):
"""
This method is used to stick the solution strategy to the nodes and add the information
@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
gm, solution, strategies_constructor)
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)