mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[autoparallel] adapt autoparallel with new analyzer (#3261)
* [autoparallel] adapt autoparallel with new analyzer * fix all node handler tests * polish * polish
This commit is contained in:
@@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
|
||||
# This stream is created for overlaping the communication and computation.
|
||||
reduction_stream = torch.cuda.Stream()
|
||||
|
||||
def _add_hook_for_grad_communication(node, param):
|
||||
def _add_hook_for_grad_communication(node, param, name=None):
|
||||
|
||||
comm_actions = node.best_strategy.communication_actions
|
||||
|
||||
def _filter_param_to_hook(node, op_data, comm_action):
|
||||
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
|
||||
def _filter_param_to_hook(node, op_data, comm_action, name):
|
||||
|
||||
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
|
||||
return True
|
||||
if node.op == 'get_attr' and isinstance(
|
||||
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
|
||||
@@ -402,7 +403,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
|
||||
for operation_data, comm_action in comm_actions.items():
|
||||
comm_spec_to_use = comm_action.comm_spec
|
||||
# register hook to the parameters
|
||||
if _filter_param_to_hook(node, operation_data, comm_action):
|
||||
if _filter_param_to_hook(node, operation_data, comm_action, name=name):
|
||||
|
||||
def wrapper(param, comm_spec, stream, overlap):
|
||||
|
||||
@@ -442,7 +443,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
|
||||
param = _shard_param(param, target_sharding_spec)
|
||||
|
||||
setattr(target_module, name, param)
|
||||
_add_hook_for_grad_communication(node, param)
|
||||
_add_hook_for_grad_communication(node, param, name)
|
||||
|
||||
sharded_buffer_dict = {}
|
||||
# apply the sharding spec of buffers
|
||||
|
Reference in New Issue
Block a user