mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[autoparallel] Add metainfo support for F.linear (#1987)
* [fx] metainfo class for auto parallel
* [fx] add unit test for linear metainfo
* [fx] fix bwd param for linear
* [fx] modify unit test
* [fx] modify unit test
* [fx] modify import
* [fx] modify import
* [fx] modify import
* [fx] move meta profiler to auto parallel
* [fx] add conv metainfo class
* [fx] restore profiler
* [fx] restore meta profiler
* [autoparallel] modify unit test
* [fx] modify unit test
* [autoparallel] add batchnorm metainfo class
* [autoparallel] fix batchnorm unit test function declaration
* [fx] restore profiler
* [fx] add relu metainfo class
* [fx] restore profiler
* [autoparallel] modify metainfo input
* [autoparallel] add pooling metainfo
* [autoparallel] add F.linear metainfo generator
This commit is contained in:
@@ -19,10 +19,13 @@ from ..registry import meta_register
|
||||
__all__ = ['linear_meta_info']
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.functional.linear)
|
||||
@meta_register.register(torch.nn.Linear)
|
||||
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.Linear meta info generator
|
||||
The atens graph of torch.nn.Linear with bias is
|
||||
"""torch.nn.Linear & torch.nn.functional.linear meta info generator
|
||||
NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator,
|
||||
but we will hold the bias mechanism in the linear metainfo generator for future use.
|
||||
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
|
||||
@@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
||||
has_bias: bool = False
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
|
||||
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
|
||||
|
||||
# process the dimension of input and output
|
||||
if len(input_tensor.shape) > 2:
|
||||
@@ -76,9 +79,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
||||
output_tensor: torch.Tensor
|
||||
output_tensor = output_tensor.view(-1, output_tensor.shape[-1])
|
||||
|
||||
if len(args) == 4:
|
||||
bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
|
||||
if len(weight_tensors) > 1:
|
||||
has_bias = True
|
||||
if len(weight_tensors[0].shape) == 2:
|
||||
weight_tensor, bias_tensor = weight_tensors
|
||||
else:
|
||||
bias_tensor, weight_tensor = weight_tensors
|
||||
else:
|
||||
weight_tensor = weight_tensors[0]
|
||||
|
||||
if has_bias:
|
||||
# calculate cost with bias
|
||||
|
@@ -92,8 +92,12 @@ class MetaInfo:
|
||||
Compute meta info based on sharding strategy and the given target function.
|
||||
"""
|
||||
|
||||
assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
|
||||
meta_func = meta_register.get(self._target.__class__)
|
||||
try:
|
||||
# module
|
||||
meta_func = meta_register.get(self._target.__class__)
|
||||
except:
|
||||
# function
|
||||
meta_func = meta_register.get(self._target)
|
||||
|
||||
# construct args for meta_func
|
||||
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
|
||||
|
Reference in New Issue
Block a user