add overlap option (#2613)

This commit is contained in:
YuliangLiu0306
2023-02-08 15:02:31 +08:00
committed by GitHub
parent cb3d1bef62
commit 28398f1c70
2 changed files with 32 additions and 16 deletions

View File

@@ -352,7 +352,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
return gm
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of solver solution.
@@ -387,15 +387,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# register hook to the parameters
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use, reduction_stream)
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
@@ -441,15 +444,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(target, comm_spec_to_use, reduction_stream)
wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
return gm
@@ -463,13 +469,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor = None):
strategies_constructor: StrategiesConstructor = None,
overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, solution, strategies_constructor)
gm = _size_value_converting(gm, device_mesh)
gm = _node_args_converting(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# gm = implicit_comm_action_apply(gm)
gm = _module_params_sharding(gm, device_mesh)
gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict