Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[autoparallel] added sharding spec conversion for linear handler (#1687)
@@ -6,6 +6,8 @@ from enum import Enum
from functools import reduce
import operator

__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
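For context, these constants feed the cost model that scores how expensive it is to convert one sharding layout into another (the conversion this PR wires into the linear handler). The snippet below is a purely hypothetical illustration of how such per-step costs could combine, not code from the repository; the actual formula lives in the _DimSpec difference logic, and the import path colossalai.tensor.sharding_spec is assumed.

# Hypothetical illustration (not repository code): combining the cost constants
# the way a single conversion step might be scored. Import path is an assumption.
from colossalai.tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY

# e.g. a step that all-gathers one tensor dimension and re-shards another,
# plus the flat per-step penalty:
hypothetical_step_cost = ALLGATHER_COST + SHARD_COST + STEP_PENALTY   # 20 + 5 + 6 = 31
print(hypothetical_step_cost)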
@@ -136,6 +138,10 @@ class _DimSpec:
        return difference


class ShardingException(Exception):
    pass


class ShardingSpec:
    '''
    Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
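To make the docstring concrete, here is a minimal sketch of building a ShardingSpec over a 2D device mesh. It assumes DeviceMesh is importable from colossalai.device.device_mesh and that the constructor signature is ShardingSpec(device_mesh, entire_shape, dim_partition_dict=...); treat those names as assumptions rather than the exact API of this commit.

# Minimal sketch, assuming the import paths and constructor signatures noted above.
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# 4 physical devices arranged as a 2x2 logical mesh.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))

# A 16x32 tensor sharded along dim 0 on mesh axis 0 and replicated on dim 1,
# i.e. a sharding sequence like [S0, R].
entire_shape = torch.Size((16, 32))
spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={0: [0]})
print(spec.sharding_sequence)   # attribute name assumed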