fix typo tests/ (#3936)

This commit is contained in:
digger yu
2023-06-09 09:49:41 +08:00
committed by GitHub
parent bd2c7c3297
commit e61ffc77c6
4 changed files with 8 additions and 8 deletions

View File

@@ -27,7 +27,7 @@ def check_bn_module_handler(rank, world_size, port):
# the index of bn node in computation graph
node_index = 1
# the total number of bn strategies without sync bn mode
# TODO: add sync bn stategies after related passes ready
# TODO: add sync bn strategies after related passes ready
strategy_number = 4
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,

View File

@@ -43,14 +43,14 @@ def test_output_handler(output_option):
output_strategies_vector = StrategiesVector(output_node)
# build handler
otuput_handler = OutputHandler(node=output_node,
output_handler = OutputHandler(node=output_node,
device_mesh=device_mesh,
strategies_vector=output_strategies_vector,
output_option=output_option)
otuput_handler.register_strategy(compute_resharding_cost=False)
output_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping
mapping = otuput_handler.get_operation_data_mapping()
mapping = output_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
@@ -59,7 +59,7 @@ def test_output_handler(output_option):
assert mapping['output'].name == "output"
assert mapping['output'].type == OperationDataType.OUTPUT
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
strategy_name_list = [val.name for val in output_handler.strategies_vector]
if output_option == 'distributed':
assert "Distributed Output" in strategy_name_list
else: