[autoparallel] fix C version rotor inconsistency (#1691)

This commit is contained in:
Boyuan Yao
2022-10-12 15:21:58 +08:00
committed by GitHub
parent 363fc2861a
commit 31d2f03d27
3 changed files with 54 additions and 22 deletions

View File

@@ -26,7 +26,7 @@ except:
def _run_C_solver_consistency_test(rank=0):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]:
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()
data = torch.rand(128, 3, 224, 224, device='meta')
@@ -41,15 +41,24 @@ def _run_C_solver_consistency_test(rank=0):
# python solver
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True)
sequence_python: Sequence = copy.deepcopy(gm.__sequence__)
opt_python = copy.deepcopy(gm.__opttable__)
# C solver
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024)
sequence_C: Sequence = copy.deepcopy(gm.__sequence__)
opt_C = copy.deepcopy(gm.__opttable__)
# make sure the opt_tables are the same
for m in range(len(opt_python)):
for d in range(1, len(opt_python[0])):
for i in range(len(opt_python[0]) - d):
assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \
f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}"
sequence_python = sequence_python.list_operations()
sequence_C = sequence_C.list_operations()
# make sure the solutions are the same
# make sure the sequences are the same
assert len(sequence_python) == len(sequence_C) and \
all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C))