mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[autoparallel] fix C version rotor inconsistency (#1691)
This commit is contained in:
@@ -26,7 +26,7 @@ except:
|
||||
def _run_C_solver_consistency_test(rank=0):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
|
||||
for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]:
|
||||
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
|
||||
model = M()
|
||||
data = torch.rand(128, 3, 224, 224, device='meta')
|
||||
|
||||
@@ -41,15 +41,24 @@ def _run_C_solver_consistency_test(rank=0):
|
||||
# python solver
|
||||
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True)
|
||||
sequence_python: Sequence = copy.deepcopy(gm.__sequence__)
|
||||
opt_python = copy.deepcopy(gm.__opttable__)
|
||||
|
||||
# C solver
|
||||
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024)
|
||||
sequence_C: Sequence = copy.deepcopy(gm.__sequence__)
|
||||
opt_C = copy.deepcopy(gm.__opttable__)
|
||||
|
||||
# make sure the opt_tables are the same
|
||||
for m in range(len(opt_python)):
|
||||
for d in range(1, len(opt_python[0])):
|
||||
for i in range(len(opt_python[0]) - d):
|
||||
assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \
|
||||
f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}"
|
||||
|
||||
sequence_python = sequence_python.list_operations()
|
||||
sequence_C = sequence_C.list_operations()
|
||||
|
||||
# make sure the solutions are the same
|
||||
# make sure the sequences are the same
|
||||
assert len(sequence_python) == len(sequence_C) and \
|
||||
all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C))
|
||||
|
||||
|
Reference in New Issue
Block a user