[autoparallel] fix param hook issue in transform pass (#1755)

YuliangLiu0306
2022-10-24 13:13:38 +08:00
committed by GitHub
parent 262652c8bc
commit d2fc067231
2 changed files with 85 additions and 11 deletions


@@ -70,10 +70,14 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                     comm_spec_to_use = comm_action.comm_spec
                     if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
 
-                        def hook_fn(grad):
-                            _all_reduce(grad, comm_spec_to_use)
+                        def wrapper(param, comm_spec):
 
-                        param_sharded.register_hook(hook_fn)
+                            def hook_fn(grad):
+                                _all_reduce(grad, comm_spec)
+
+                            param.register_hook(hook_fn)
+
+                        wrapper(param_sharded, comm_spec_to_use)
 
     sharded_buffer_dict = {}
     for name, buffer in target_module.named_buffers():
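
The hunk above is the whole substance of the fix. A plausible reading (the commit message does not spell it out): the original hook_fn closed over comm_spec_to_use, a loop variable, and Python closures resolve such names when the hook fires rather than when the function is defined, so every registered gradient hook would all-reduce with whichever comm spec the loop assigned last. Wrapping the registration in wrapper(param, comm_spec) turns the current parameter and comm spec into function arguments, freezing their values for each iteration. A minimal standalone sketch of the pitfall and the fix, using placeholder strings instead of the real CommSpec objects and returning the hooks instead of registering them on a tensor:

    # Hypothetical demonstration of the late-binding pitfall; not code from the commit.
    specs = ["spec_a", "spec_b", "spec_c"]

    # Buggy pattern: `spec` is looked up when the hook runs, so after the loop
    # finishes, every hook sees the final value "spec_c".
    buggy_hooks = []
    for spec in specs:
        def hook_fn(grad):
            return (grad, spec)
        buggy_hooks.append(hook_fn)

    # Fixed pattern (the same technique the commit uses): passing the loop
    # variable into a wrapper makes it a local argument, pinned per iteration.
    fixed_hooks = []
    for spec in specs:
        def wrapper(comm_spec):
            def hook_fn(grad):
                return (grad, comm_spec)
            return hook_fn
        fixed_hooks.append(wrapper(spec))

    assert [h("g")[1] for h in buggy_hooks] == ["spec_c"] * 3
    assert [h("g")[1] for h in fixed_hooks] == ["spec_a", "spec_b", "spec_c"]

A default argument (def hook_fn(grad, spec=spec): ...) or functools.partial are common alternative ways to pin the value per iteration.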