[autoparallel] accelerate gpt2 training (#2495)

This commit is contained in:
YuliangLiu0306
2023-01-29 11:13:15 +08:00
committed by GitHub
parent a360b9bc44
commit aa0f6686f9
5 changed files with 21 additions and 17 deletions

View File

@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# register hook to the parameters
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec):
def wrapper(param, comm_spec, stream):
def hook_fn(grad):
_all_reduce(grad, comm_spec, async_op=False)
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use)
wrapper(param, comm_spec_to_use, reduction_stream)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
@@ -440,14 +441,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec):
def wrapper(param, comm_spec, stream):
def hook_fn(grad):
_all_reduce(grad, comm_spec, async_op=False)
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
param.register_hook(hook_fn)
wrapper(target, comm_spec_to_use)
wrapper(target, comm_spec_to_use, reduction_stream)
return gm