[autoparallel] adapt solver with self attention (#2037)

* [autoparallel] adapt solver with self attention

* polish code
YuliangLiu0306
2022-12-01 17:53:15 +08:00
committed by GitHub
parent d3499c98d4
commit 1c1fe44305
6 changed files with 320 additions and 13 deletions


@@ -9,7 +9,14 @@ from torch.fx.node import Node
 from colossalai.tensor.shape_consistency import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
+from .constants import (
+    BCAST_FUNC_OP,
+    ELEMENTWISE_FUNC_OP,
+    ELEMENTWISE_METHOD_OP,
+    ELEMENTWISE_MODULE_OP,
+    RESHAPE_FUNC_OP,
+    RESHAPE_METHOD_OP,
+)
 
 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
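The two constants added to this import, ELEMENTWISE_METHOD_OP and RESHAPE_METHOD_OP, mirror the existing function-level tables but key on torch.Tensor methods rather than free functions. A minimal sketch of what such method-level tables could look like is below; the member lists are illustrative assumptions, not the verbatim definitions in constants.py.

import torch

# Tensor methods with reshape-like semantics (assumed members, for illustration only).
RESHAPE_METHOD_OP = [
    torch.Tensor.view,
    torch.Tensor.reshape,
    torch.Tensor.permute,
    torch.Tensor.transpose,
]

# Tensor methods that act element-wise on their input (assumed members, for illustration only).
ELEMENTWISE_METHOD_OP = [
    torch.Tensor.to,
    torch.Tensor.contiguous,
]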
@@ -249,8 +256,15 @@ class StrategiesVector(list):
         # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
         if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
             merge_label = True
-        # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+        # we could merge reshape ops, because their computation costs are negligible.
         if self.node.target in RESHAPE_FUNC_OP:
             merge_label = True
+        if self.node.op == 'call_method':
+            # we could also merge reshape-like and element-wise method ops, because their computation costs are negligible.
+            method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+            if method in RESHAPE_METHOD_OP:
+                merge_label = True
+            if method in ELEMENTWISE_METHOD_OP:
+                merge_label = True
 
         return merge_label
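For a call_method node, node.target is only the method name as a string (e.g. 'view'), so the branch above resolves it back to the class-level torch.Tensor method via getattr on the argument's meta-data class before testing membership. A self-contained sketch of that resolution, under the assumption that the method tables hold unbound torch.Tensor methods (the table contents and the helper name are hypothetical):

import torch

# Assumed contents, for illustration only.
RESHAPE_METHOD_OP = [torch.Tensor.view, torch.Tensor.reshape]
ELEMENTWISE_METHOD_OP = [torch.Tensor.to, torch.Tensor.contiguous]

def is_mergeable_method(method_name: str, arg_meta: torch.Tensor) -> bool:
    # Resolve the string target of a call_method node to the unbound method
    # on the argument's class, then check it against the method tables.
    method = getattr(arg_meta.__class__, method_name)
    return method in RESHAPE_METHOD_OP or method in ELEMENTWISE_METHOD_OP

x = torch.empty(2, 3, device='meta')
print(is_mergeable_method('view', x))    # True: reshape-like, its strategies can be merged into the predecessor
print(is_mergeable_method('matmul', x))  # False: compute-heavy, kept as its own node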