mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 20:54:55 +00:00
[autoparallel] fix param hook issue in transform pass (#1755)
This commit is contained in:
parent
262652c8bc
commit
d2fc067231
@ -70,10 +70,14 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
|
|||||||
comm_spec_to_use = comm_action.comm_spec
|
comm_spec_to_use = comm_action.comm_spec
|
||||||
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
|
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
|
||||||
|
|
||||||
def hook_fn(grad):
|
def wrapper(param, comm_spec):
|
||||||
_all_reduce(grad, comm_spec_to_use)
|
|
||||||
|
|
||||||
param_sharded.register_hook(hook_fn)
|
def hook_fn(grad):
|
||||||
|
_all_reduce(grad, comm_spec)
|
||||||
|
|
||||||
|
param.register_hook(hook_fn)
|
||||||
|
|
||||||
|
wrapper(param_sharded, comm_spec_to_use)
|
||||||
|
|
||||||
sharded_buffer_dict = {}
|
sharded_buffer_dict = {}
|
||||||
for name, buffer in target_module.named_buffers():
|
for name, buffer in target_module.named_buffers():
|
||||||
|
@ -171,22 +171,92 @@ def check_apply_bottleneck(rank, world_size, port):
|
|||||||
torch.cuda.set_rng_state(cuda_rng_state)
|
torch.cuda.set_rng_state(cuda_rng_state)
|
||||||
origin_output.sum().backward()
|
origin_output.sum().backward()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
|
print(
|
||||||
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
|
f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 0, 1)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
|
||||||
|
|
||||||
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
|
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
|
||||||
|
assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
|
||||||
|
assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
|
||||||
|
|
||||||
if rank == 1:
|
if rank == 1:
|
||||||
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
|
print(
|
||||||
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
|
f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 1, 1)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
|
||||||
|
|
||||||
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
|
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
|
||||||
|
assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
|
||||||
|
assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
|
||||||
|
|
||||||
if rank == 2:
|
if rank == 2:
|
||||||
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
|
print(
|
||||||
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
|
f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 2, 1)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
|
||||||
|
|
||||||
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
|
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
|
||||||
|
assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
|
||||||
|
assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
|
||||||
|
|
||||||
if rank == 3:
|
if rank == 3:
|
||||||
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
|
print(
|
||||||
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
|
f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 3, 1)).abs().sum()}"
|
||||||
|
)
|
||||||
|
print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
|
||||||
|
|
||||||
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
|
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
|
||||||
|
assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
|
||||||
|
assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
|
||||||
|
|
||||||
|
|
||||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
|
Loading…
Reference in New Issue
Block a user