[autoparallel] fix bias addition module (#1800)

Author: YuliangLiu0306
Date: 2022-11-08 16:21:25 +08:00
Committed by: GitHub
Parent: 6e9730d7ab
Commit: f6032ddb17

9 changed files with 438 additions and 20 deletions


@@ -29,8 +29,15 @@ class ReshapeHandler(NodeHandler):
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to their original shape in self.post_process
+        # check if the input operand is a parameter
+        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
+            data_type = OperationDataType.PARAM
+        else:
+            data_type = OperationDataType.ARG
         physical_input_operand = OperationData(name=str(self.node.args[0]),
-                                               type=OperationDataType.ARG,
+                                               type=data_type,
                                                data=self.node.args[0]._meta_data)
         physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
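
This hunk appears to be the handler-side half of the bias-addition fix: once the bias is routed through this node, self.node.args[0]._meta_data can be an nn.Parameter, and the old code tagged every first argument as ARG. Below is a minimal, self-contained sketch of the PARAM-vs-ARG decision; the OperationDataType enum and classify_operand helper are illustrative stand-ins, not ColossalAI's actual classes.

from enum import Enum

import torch
import torch.nn as nn


class OperationDataType(Enum):
    # illustrative stand-in for ColossalAI's enum of the same name
    ARG = 0
    PARAM = 1
    OUTPUT = 2


def classify_operand(meta_data) -> OperationDataType:
    # a bias taken from an nn.Linear is still an nn.Parameter,
    # so it is tagged PARAM; plain activation tensors stay ARG
    if isinstance(meta_data, torch.nn.parameter.Parameter):
        return OperationDataType.PARAM
    return OperationDataType.ARG


linear = nn.Linear(4, 4)
print(classify_operand(linear.bias))        # OperationDataType.PARAM
print(classify_operand(torch.randn(2, 4)))  # OperationDataType.ARG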


@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
                     arg_index=0)
                 input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
-            else:
+            elif len(total_mesh_dim_list) >= 2:
                 source_spec = sharding_spec_mapping["input"]
                 target_spec = ShardingSpec(device_mesh=self.device_mesh,
                                            entire_shape=source_spec.entire_shape,
@@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
                 comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
                 input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
-            communication_action_mapping["input"] = input_comm_action
+            else:
+                input_comm_action = None
+
+            if input_comm_action is not None:
+                communication_action_mapping["input"] = input_comm_action
             strategy = self.get_sharding_strategy(name=name,
                                                   sharding_spec_mapping=sharding_spec_mapping,
                                                   communication_action_mapping=communication_action_mapping)
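
Read together, the two ReshapeGenerator hunks turn the old two-way branch into a three-way one: a single sharded mesh dimension triggers a gather, two or more trigger a source-to-target spec conversion, and an input that is not sharded at all now yields no communication action, so the mapping is only populated when an action actually exists. A rough, runnable sketch of that control flow follows, with invented helper names and plain strings standing in for ColossalAI's CommAction objects.

from typing import Dict, List, Optional


def pick_input_comm_action(total_mesh_dim_list: List[int]) -> Optional[str]:
    # mirrors the branch structure introduced by the two hunks above
    if len(total_mesh_dim_list) == 1:
        # one sharded mesh dimension: gather the input before the op
        return f"gather(dim={total_mesh_dim_list[0]})"
    elif len(total_mesh_dim_list) >= 2:
        # several sharded mesh dimensions: build a src -> tgt spec conversion
        return "convert(src_spec -> tgt_spec)"
    else:
        # input not sharded on any mesh dimension: nothing to communicate
        return None


def build_mapping(total_mesh_dim_list: List[int]) -> Dict[str, str]:
    communication_action_mapping: Dict[str, str] = {}
    input_comm_action = pick_input_comm_action(total_mesh_dim_list)
    # only register an action when one exists, which is the case the fix adds
    if input_comm_action is not None:
        communication_action_mapping["input"] = input_comm_action
    return communication_action_mapping


print(build_mapping([0]))     # {'input': 'gather(dim=0)'}
print(build_mapping([0, 1]))  # {'input': 'convert(src_spec -> tgt_spec)'}
print(build_mapping([]))      # {}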