mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[autoparallel] fix bias addition module (#1800)
This commit is contained in:
@@ -29,8 +29,15 @@ class ReshapeHandler(NodeHandler):
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# use transposed shape for strategies
|
||||
# the strategies will be transformed back to its original shape in self.post_process
|
||||
|
||||
# check if the input operand is a parameter
|
||||
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]),
|
||||
type=OperationDataType.ARG,
|
||||
type=data_type,
|
||||
data=self.node.args[0]._meta_data)
|
||||
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||||
|
||||
|
@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
|
||||
arg_index=0)
|
||||
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
|
||||
|
||||
else:
|
||||
elif len(total_mesh_dim_list) >= 2:
|
||||
source_spec = sharding_spec_mapping["input"]
|
||||
target_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=source_spec.entire_shape,
|
||||
@@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
|
||||
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
|
||||
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
|
||||
|
||||
communication_action_mapping["input"] = input_comm_action
|
||||
else:
|
||||
input_comm_action = None
|
||||
|
||||
if input_comm_action is not None:
|
||||
communication_action_mapping["input"] = input_comm_action
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
Reference in New Issue
Block a user