Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-15 22:19:38 +00:00
[autoparallel] adapt solver with self attention (#2037)
* [autoparallel] adapt solver with self attention
* polish code
@@ -9,7 +9,14 @@ from torch.fx.node import Node
 from colossalai.tensor.shape_consistency import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
+from .constants import (
+    BCAST_FUNC_OP,
+    ELEMENTWISE_FUNC_OP,
+    ELEMENTWISE_METHOD_OP,
+    ELEMENTWISE_MODULE_OP,
+    RESHAPE_FUNC_OP,
+    RESHAPE_METHOD_OP,
+)
 
 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
 
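The two newly imported names, ELEMENTWISE_METHOD_OP and RESHAPE_METHOD_OP, are op tables keyed by torch.Tensor methods rather than by free functions, mirroring the existing *_FUNC_OP tables. The sketch below is illustrative only; the actual tables live in ColossalAI's .constants module and may list different methods:

import torch

# Hypothetical contents, for illustration only; see ColossalAI's .constants for the real tables.
RESHAPE_METHOD_OP = [
    torch.Tensor.view,
    torch.Tensor.reshape,
    torch.Tensor.permute,
    torch.Tensor.transpose,
]
ELEMENTWISE_METHOD_OP = [
    torch.Tensor.to,
    torch.Tensor.contiguous,
]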
@@ -249,8 +256,15 @@ class StrategiesVector(list):
         # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
         if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
             merge_label = True
-        # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+        # we could merge reshape op, because their computation costs are negligible.
         if self.node.target in RESHAPE_FUNC_OP:
             merge_label = True
 
+        if self.node.op == 'call_method':
+            # we could merge reshape op, because their computation costs are negligible.
+            method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+            if method in RESHAPE_METHOD_OP:
+                merge_label = True
+            if method in ELEMENTWISE_METHOD_OP:
+                merge_label = True
+
         return merge_label
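The added call_method branch resolves the method object from the class of the first argument's _meta_data and marks the node as mergeable when that method is reshape-like or elementwise-like, since such nodes add negligible computation cost. Below is a self-contained sketch of that resolution pattern; the helper name is hypothetical, a plain tensor stands in for the FX node's _meta_data, and the tables reuse the illustrative entries from above:

import torch

RESHAPE_METHOD_OP = {torch.Tensor.view, torch.Tensor.permute, torch.Tensor.transpose}  # illustrative
ELEMENTWISE_METHOD_OP = {torch.Tensor.to, torch.Tensor.contiguous}                      # illustrative

def is_mergeable_method(meta_tensor: torch.Tensor, target: str) -> bool:
    # Resolve the unbound method from the argument's class, as the solver does with
    # getattr(...), then test membership in the method-op tables.
    method = getattr(meta_tensor.__class__, target)
    return method in RESHAPE_METHOD_OP or method in ELEMENTWISE_METHOD_OP

x = torch.empty(2, 3)
print(is_mergeable_method(x, 'view'))  # True: view is reshape-like
print(is_mergeable_method(x, 'add'))   # False: add is in neither illustrative table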