mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[autoparallel] use metainfo in handler (#2149)
This commit is contained in:
@@ -28,7 +28,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
inplace = kwargs.get("inplace", False)
|
||||
|
||||
|
@@ -58,9 +58,12 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
||||
"""
|
||||
|
||||
has_bias: bool = False
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
|
||||
if len(args) == 4:
|
||||
weight_tensors = [args[1].data, args[3].data]
|
||||
else:
|
||||
weight_tensors = [args[1].data]
|
||||
|
||||
# check if conv has bias
|
||||
if len(weight_tensors) > 1:
|
||||
|
@@ -66,9 +66,13 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
||||
"""
|
||||
|
||||
has_bias: bool = False
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
|
||||
|
||||
input_tensor = args[0].data
|
||||
output_tensor = args[2].data
|
||||
if len(args) == 4:
|
||||
weight_tensors = [args[1].data, args[3].data]
|
||||
else:
|
||||
weight_tensors = [args[1].data]
|
||||
|
||||
# process the dimension of input and output
|
||||
if len(input_tensor.shape) > 2:
|
||||
|
@@ -45,7 +45,7 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
|
||||
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
|
||||
|
@@ -30,7 +30,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
|
||||
# construct forward args for flop mapping
|
||||
|
Reference in New Issue
Block a user