Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 22:52:25 +00:00)
[autoparallel] fix param hook issue in transform pass (#1755)
@@ -70,10 +70,14 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                 comm_spec_to_use = comm_action.comm_spec
                 if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-                    def hook_fn(grad):
-                        _all_reduce(grad, comm_spec_to_use)
-
-                    param_sharded.register_hook(hook_fn)
+                    def wrapper(param, comm_spec):
+
+                        def hook_fn(grad):
+                            _all_reduce(grad, comm_spec)
+
+                        param.register_hook(hook_fn)
+
+                    wrapper(param_sharded, comm_spec_to_use)

         sharded_buffer_dict = {}
         for name, buffer in target_module.named_buffers():
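The wrapper is needed because a hook defined directly inside the parameter loop closes over comm_spec_to_use by name, so every registered hook would end up using whichever comm spec the loop saw last. Passing the spec through a function argument binds the value at registration time, once per parameter. Below is a minimal standalone sketch of the difference, with hypothetical helper names; it is plain PyTorch, not the ColossalAI code itself.

import torch

specs_seen = []

def register_hooks_buggy(params, comm_specs):
    # Late-binding pitfall: every hook_fn closes over the same 'comm_spec'
    # name, so after the loop finishes all hooks observe comm_specs[-1].
    for param, comm_spec in zip(params, comm_specs):

        def hook_fn(grad):
            specs_seen.append(comm_spec)
            return grad

        param.register_hook(hook_fn)

def register_hooks_fixed(params, comm_specs):
    # Mirrors the patch: the wrapper's arguments are evaluated at call time,
    # so each hook keeps the comm spec belonging to its own parameter.
    for param, comm_spec in zip(params, comm_specs):

        def wrapper(param, comm_spec):

            def hook_fn(grad):
                specs_seen.append(comm_spec)
                return grad

            param.register_hook(hook_fn)

        wrapper(param, comm_spec)

if __name__ == "__main__":
    params = [torch.ones(2, requires_grad=True) for _ in range(3)]
    comm_specs = ["spec_a", "spec_b", "spec_c"]
    register_hooks_fixed(params, comm_specs)
    sum(p.sum() for p in params).backward()
    # Prints ['spec_a', 'spec_b', 'spec_c'] with the fix; the buggy
    # version would record 'spec_c' three times.
    print(sorted(specs_seen))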