[autoparallel] added sharding spec conversion for linear handler (#1687)

This commit is contained in:
Frank Lee
2022-10-12 11:16:18 +08:00
committed by GitHub
parent af718e83f2
commit 4973157ad7
6 changed files with 222 additions and 43 deletions

View File

@@ -6,6 +6,8 @@ from enum import Enum
from functools import reduce
import operator
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
@@ -136,6 +138,10 @@ class _DimSpec:
return difference
class ShardingException(Exception):
    """Error raised for invalid or failed sharding-spec operations."""
class ShardingSpec:
'''
Sharding spec for a tensor; it contains info about the logical device mesh this tensor belongs