[autoparallel] use metainfo in handler (#2149)

This commit is contained in:
YuliangLiu0306
2022-12-20 10:31:22 +08:00
committed by GitHub
parent 9b39170a5c
commit 1cce6e36ca
11 changed files with 105 additions and 31 deletions

View File

@@ -58,9 +58,12 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
"""
has_bias: bool = False
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
if len(args) == 4:
weight_tensors = [args[1].data, args[3].data]
else:
weight_tensors = [args[1].data]
# check if conv has bias
if len(weight_tensors) > 1: