[tensor] a shorter shard and replicate spec (#1245)

This commit is contained in:
Jiarui Fang
2022-07-11 15:51:48 +08:00
committed by GitHub
parent 2699dfbbfd
commit 9bcd2fd4af
25 changed files with 91 additions and 98 deletions

View File

@@ -7,16 +7,6 @@ class ColoModule(object):
def __init__(self):
self._shard_params: List[str] = []
# Example:
# {ComputePattern.TP1D:
# 'default':
# 'weight':
# distspec.shard(xxxxx)
# 'bias':
# distspec.shard(xxxxx)
# 'row': ...
# 'col': ...
# }
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
def _register_shard_params(self, params: List[str]):