[autoparallel] fix bias addition module (#1800)

This commit is contained in:
YuliangLiu0306
2022-11-08 16:21:25 +08:00
committed by GitHub
parent 6e9730d7ab
commit f6032ddb17
9 changed files with 438 additions and 20 deletions

View File

@@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
user_node.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
@@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
if op_data.type == OperationDataType.PARAM:
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
if comm_action.key_for_kwarg is not None:
if op_data.type == OperationDataType.OUTPUT:
comm_object = node
elif comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
@@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
node.args = tuple(new_args)
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
@@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
return gm

View File

@@ -5,7 +5,12 @@ import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -42,7 +47,32 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
new_sharding_spec = target_sharding_specs[0]
user_node = node.strategies_vector.successor_nodes[0]
user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
op_data_in_user = user_strategy.get_op_data_by_name(str(node))
origin_node_sharding_spec_dict[index] = new_sharding_spec
origin_pending_strategy = node.best_strategy
origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
new_sharding_specs = origin_pending_strategy.sharding_specs
new_sharding_specs[origin_op_data] = new_sharding_spec
new_communication_actions = {}
if op_data_in_user in user_strategy.communication_actions:
new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
new_communication_action.arg_index = 0
new_communication_actions[origin_op_data] = new_communication_action
new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
sharding_specs=new_sharding_specs,
compute_cost=origin_pending_strategy.compute_cost,
communication_cost=origin_pending_strategy.communication_cost,
memory_cost=origin_pending_strategy.memory_cost,
communication_actions=new_communication_actions)
setattr(node, 'best_strategy', new_strategy)
setattr(node, 'sharding_spec', new_sharding_spec)
comm_action_dict = {}
for op_data, comm_action in node.best_strategy.communication_actions.items():
comm_action_dict[op_data.name] = comm_action
@@ -111,6 +141,43 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
if node.op == 'get_attr':
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
if attr_len == 1:
target_module = root
target = getattr(root, atoms[0])
else:
target_module = root.get_submodule(atoms[-2])
target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
setattr(target, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParamter class to manager the distributed parameters
target_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
target_sharding_spec).detach().clone())
else:
target_sharded = target
setattr(target_module, atoms[-1], target_sharded)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec):
def hook_fn(grad):
_all_reduce(grad, comm_spec)
param.register_hook(hook_fn)
wrapper(target_sharded, comm_spec_to_use)
return gm

View File

@@ -29,8 +29,15 @@ class ReshapeHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
type=data_type,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

View File

@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
else:
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
@@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
communication_action_mapping["input"] = input_comm_action
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)