Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[autoparallel] accelerate gpt2 training (#2495)
@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
 # register hook to the parameters
 if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-    def wrapper(param, comm_spec):
+    def wrapper(param, comm_spec, stream):

         def hook_fn(grad):
-            _all_reduce(grad, comm_spec, async_op=False)
+            with torch.cuda.stream(stream):
+                _all_reduce(grad, comm_spec, async_op=True)

         param.register_hook(hook_fn)

-    wrapper(param, comm_spec_to_use)
+    wrapper(param, comm_spec_to_use, reduction_stream)

 sharded_buffer_dict = {}
 # apply the sharding spec of buffers
@@ -440,14 +441,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
 # register hook to the parameters
 if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

-    def wrapper(param, comm_spec):
+    def wrapper(param, comm_spec, stream):

         def hook_fn(grad):
-            _all_reduce(grad, comm_spec, async_op=False)
+            with torch.cuda.stream(stream):
+                _all_reduce(grad, comm_spec, async_op=True)

         param.register_hook(hook_fn)

-    wrapper(target, comm_spec_to_use)
+    wrapper(target, comm_spec_to_use, reduction_stream)

 return gm
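Both hunks make the same change: instead of a blocking all-reduce on the default stream, each parameter's gradient hook now launches the all-reduce asynchronously on a dedicated reduction stream, so gradient communication can overlap with the rest of the backward pass. The sketch below illustrates that pattern in plain PyTorch; the function name register_async_grad_allreduce, the direct use of torch.distributed.all_reduce, and the explicit stream synchronization are illustrative stand-ins for ColossalAI's _all_reduce and comm-spec machinery, not the library's actual API.

# Minimal sketch of the pattern in this commit, assuming torch.distributed has
# already been initialized with the NCCL backend. All names are illustrative,
# not ColossalAI's actual API.
import torch
import torch.distributed as dist


def register_async_grad_allreduce(module: torch.nn.Module, reduction_stream: torch.cuda.Stream):
    """Register a hook on every parameter that all-reduces its gradient on a side stream."""

    def wrapper(param, stream):

        def hook_fn(grad):
            # The side stream must wait until the gradient has been produced on
            # the stream that runs the backward pass.
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                # Keep the gradient's memory alive while the side stream uses it.
                grad.record_stream(stream)
                # async_op=True returns immediately; the collective is queued on
                # `stream` and overlaps with the remaining backward computation.
                dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)
            return grad

        param.register_hook(hook_fn)

    for param in module.parameters():
        if param.requires_grad:
            wrapper(param, reduction_stream)

Because the collectives now run on a side stream, the optimizer must not read the gradients until that stream has drained. A usage sketch (model, loss, and optimizer assumed to exist):

reduction_stream = torch.cuda.Stream()
register_async_grad_allreduce(model, reduction_stream)

loss.backward()
torch.cuda.current_stream().wait_stream(reduction_stream)    # gradients fully reduced
optimizer.step()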