[autoparallel] remove redundancy comm node (#1893)
@@ -47,6 +47,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
+        setattr(node, 'target_sharding_specs', target_sharding_specs)
 
         # the get_attr node strategy is kind of pending strategy, which means we will change it
         # to the same strategy of the user node.
         if node.op == 'get_attr':
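
The added setattr line caches the sharding specs expected by a node's users directly on the torch.fx node, so later runtime passes can read node.target_sharding_specs instead of re-querying each user's strategy. A minimal sketch of that annotation pattern, using a hypothetical toy module (only the attribute name comes from the diff; everything else below is illustrative):

import torch
from torch.fx import symbolic_trace


class ToyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = symbolic_trace(ToyModel())

# Annotation pass: stash per-node metadata on the fx Node objects themselves.
for node in gm.graph.nodes:
    # In the real pass this list would hold the sharding spec each user node expects.
    setattr(node, 'target_sharding_specs', [])

# A later pass can then read the cached metadata without recomputing it.
for node in gm.graph.nodes:
    specs = getattr(node, 'target_sharding_specs', None)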
@@ -95,7 +96,8 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
     """
     mod_graph = gm.graph
     nodes = tuple(mod_graph.nodes)
-
+    # This stream is created for overlaping the communication and computation.
+    reduction_stream = torch.cuda.Stream()
     for node in nodes:
         if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)
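
The new reduction_stream is a side CUDA stream intended to overlap gradient communication with computation still running on the default stream. A hedged sketch of that general overlap pattern with plain torch.distributed (assumes a process group is already initialized; the function name and arguments are illustrative, not part of the pass):

import torch
import torch.distributed as dist


def all_reduce_on_side_stream(grad: torch.Tensor, reduction_stream: torch.cuda.Stream) -> None:
    # Make the side stream wait until grad has been produced on the current stream.
    reduction_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(reduction_stream):
        # The collective is enqueued on the side stream, so the default stream
        # can keep launching backward kernels while the reduction runs.
        dist.all_reduce(grad)
    # Tell the caching allocator that grad is still in use on the side stream.
    grad.record_stream(reduction_stream)

Consumers on the default stream still have to synchronize with reduction_stream before reading the reduced gradient.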
@@ -122,7 +124,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
                     def wrapper(param, comm_spec):
 
                         def hook_fn(grad):
-                            _all_reduce(grad, comm_spec)
+                            _all_reduce(grad, comm_spec, async_op=False)
 
                         param.register_hook(hook_fn)
 
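
Both hook hunks switch _all_reduce to async_op=False, so the reduction is issued as a blocking call inside the gradient hook instead of leaving an asynchronous operation to be completed elsewhere. _all_reduce and comm_spec are ColossalAI internals; the sketch below mirrors the hook pattern with plain torch.distributed as a stand-in:

import torch
import torch.distributed as dist


def attach_grad_all_reduce(param: torch.nn.Parameter, group=None):

    def hook_fn(grad):
        # async_op=False issues the collective as a blocking call instead of
        # returning a work handle the caller would have to wait on later.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group, async_op=False)
        return grad

    # The hook runs each time a gradient for this parameter is produced in backward.
    param.register_hook(hook_fn)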
@@ -172,7 +174,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
                     def wrapper(param, comm_spec):
 
                         def hook_fn(grad):
-                            _all_reduce(grad, comm_spec)
+                            _all_reduce(grad, comm_spec, async_op=False)
 
                         param.register_hook(hook_fn)
 