[autoparallel] remove redundancy comm node (#1893)

This commit is contained in:
YuliangLiu0306
2022-11-15 10:53:41 +08:00
committed by GitHub
parent 9183e0dec5
commit 36c0f3ea5b
5 changed files with 23 additions and 20 deletions

View File

@@ -47,6 +47,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
@@ -95,7 +96,8 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
# This stream is created for overlaping the communication and computation.
reduction_stream = torch.cuda.Stream()
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
@@ -122,7 +124,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
def wrapper(param, comm_spec):
def hook_fn(grad):
_all_reduce(grad, comm_spec)
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
@@ -172,7 +174,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
def wrapper(param, comm_spec):
def hook_fn(grad):
_all_reduce(grad, comm_spec)
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)