[autoparallel] Add metainfo support for F.linear (#1987)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator
Author: Boyuan Yao
Date: 2022-11-23 14:12:34 +08:00 (committed by GitHub)
Parent: 2edbef13cc
Commit: 6cd784ffee
4 changed files with 82 additions and 15 deletions


@@ -19,10 +19,13 @@ from ..registry import meta_register

 __all__ = ['linear_meta_info']

+@meta_register.register(torch.nn.functional.linear)
 @meta_register.register(torch.nn.Linear)
 def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
-    """torch.nn.Linear meta info generator
+    """torch.nn.Linear & torch.nn.functional.linear meta info generator
+    NOTE: currently we separate the bias part from the biased linear ops; its memory consumption is accounted for in the add metainfo generator,
+    but the bias mechanism is kept in the linear metainfo generator for future use.
     The aten graph of torch.nn.Linear with bias is
     graph():
     %input_2 : [#users=2] = placeholder[target=placeholder](default=)
     %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
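
The %addmm_default node quoted in the docstring comes from PyTorch lowering a biased linear to a transpose plus addmm. A standalone sketch (not part of this commit) that reproduces such an aten-level trace with make_fx; the exact op set varies across PyTorch versions:

import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

def fwd(x, w, b):
    # biased linear: on recent builds this typically lowers to aten.t + aten.addmm
    return F.linear(x, w, b)

gm = make_fx(fwd)(torch.rand(2, 3), torch.rand(4, 3), torch.rand(4))
print(gm.graph)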
@@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     has_bias: bool = False
     input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
+    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]

     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
@@ -76,9 +79,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         output_tensor: torch.Tensor
         output_tensor = output_tensor.view(-1, output_tensor.shape[-1])

-    if len(args) == 4:
-        bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
+    if len(weight_tensors) > 1:
         has_bias = True
+        if len(weight_tensors[0].shape) == 2:
+            weight_tensor, bias_tensor = weight_tensors
+        else:
+            bias_tensor, weight_tensor = weight_tensors
+    else:
+        weight_tensor = weight_tensors[0]

     if has_bias:
         # calculate cost with bias
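
Why the rank check: the PARAM operands of torch.nn.Linear carry the names 'weight' and 'bias', but those of F.linear do not, so the new branch keys on tensor rank instead; a linear weight is always 2-D and a bias 1-D. A standalone sketch (toy tensors, not the commit's OperationData wrappers) of the two pieces of shape handling above:

import torch

# 1. inputs with more than 2 dims are viewed as 2-D for cost computation,
#    e.g. (batch, seq, hidden) -> (batch * seq, hidden)
x = torch.rand(4, 128, 768)
x2d = x.view(-1, x.shape[-1])
assert x2d.shape == (4 * 128, 768)

# 2. weight (2-D) and bias (1-D) are told apart by rank, whatever their order
def split_params(weight_tensors):
    if len(weight_tensors) > 1:
        if len(weight_tensors[0].shape) == 2:
            weight, bias = weight_tensors
        else:
            bias, weight = weight_tensors
        return weight, bias
    return weight_tensors[0], None

w, b = split_params([torch.rand(768), torch.rand(768, 768)])  # bias listed first
assert w.shape == (768, 768) and b.shape == (768,)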


@@ -92,8 +92,12 @@ class MetaInfo:
         Compute meta info based on sharding strategy and the given target function.
         """
-        assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
-        meta_func = meta_register.get(self._target.__class__)
+        try:
+            # module
+            meta_func = meta_register.get(self._target.__class__)
+        except:
+            # function
+            meta_func = meta_register.get(self._target)

         # construct args for meta_func
         args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
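
The try/except works because a module node's target is an nn.Module instance, so self._target.__class__ (e.g. torch.nn.Linear) is the registered key, while a function node's target is the function object itself, whose __class__ is just a plain function type and is not registered; the first lookup then fails and the fallback keys on the function object. A toy sketch (hypothetical MetaRegistry, not ColossalAI's class) of that lookup order:

import torch.nn as nn
import torch.nn.functional as F

class MetaRegistry:
    def __init__(self):
        self._store = {}

    def register(self, target):
        def wrapper(fn):
            self._store[target] = fn
            return fn
        return wrapper

    def get(self, target):
        return self._store[target]  # raises KeyError when unregistered

meta_register = MetaRegistry()

@meta_register.register(F.linear)
@meta_register.register(nn.Linear)
def linear_meta_info(*args, **kwargs):
    return "linear meta info"

def lookup(target):
    try:
        return meta_register.get(target.__class__)  # module path
    except KeyError:
        return meta_register.get(target)            # function path

assert lookup(nn.Linear(8, 4)) is linear_meta_info  # module target
assert lookup(F.linear) is linear_meta_info         # function target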