mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
fix typo tests/ (#3936)
This commit is contained in:
@@ -27,7 +27,7 @@ def check_bn_module_handler(rank, world_size, port):
|
||||
# the index of bn node in computation graph
|
||||
node_index = 1
|
||||
# the total number of bn strategies without sync bn mode
|
||||
# TODO: add sync bn stategies after related passes ready
|
||||
# TODO: add sync bn strategies after related passes ready
|
||||
strategy_number = 4
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
|
@@ -43,14 +43,14 @@ def test_output_handler(output_option):
|
||||
output_strategies_vector = StrategiesVector(output_node)
|
||||
|
||||
# build handler
|
||||
otuput_handler = OutputHandler(node=output_node,
|
||||
output_handler = OutputHandler(node=output_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=output_strategies_vector,
|
||||
output_option=output_option)
|
||||
|
||||
otuput_handler.register_strategy(compute_resharding_cost=False)
|
||||
output_handler.register_strategy(compute_resharding_cost=False)
|
||||
# check operation data mapping
|
||||
mapping = otuput_handler.get_operation_data_mapping()
|
||||
mapping = output_handler.get_operation_data_mapping()
|
||||
|
||||
for name, op_data in mapping.items():
|
||||
op_data: OperationData
|
||||
@@ -59,7 +59,7 @@ def test_output_handler(output_option):
|
||||
|
||||
assert mapping['output'].name == "output"
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
|
||||
strategy_name_list = [val.name for val in output_handler.strategies_vector]
|
||||
if output_option == 'distributed':
|
||||
assert "Distributed Output" in strategy_name_list
|
||||
else:
|
||||
|
Reference in New Issue
Block a user