[tensor] a shorter shard and replicate spec (#1245)

This commit is contained in:
Jiarui Fang
2022-07-11 15:51:48 +08:00
committed by GitHub
parent 2699dfbbfd
commit 9bcd2fd4af
25 changed files with 91 additions and 98 deletions

View File

@@ -7,16 +7,6 @@ class ColoModule(object):
def __init__(self):
self._shard_params: List[str] = []
# Example (compute pattern -> mode -> parameter name -> dist spec):
# {
#     ComputePattern.TP1D: {
#         'default': {
#             'weight': distspec.shard(xxxxx),
#             'bias': distspec.shard(xxxxx),
#         },
#         'row': ...,
#         'col': ...,
#     },
# }
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
def _register_shard_params(self, params: List[str]):

View File

@@ -1,7 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoEmbedding(ColoModule):
@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([0], [pg.tp_world_size()]),
'weight': ShardSpec([0], [pg.tp_world_size()]),
},
mode='row',
)
@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([-1], [pg.tp_world_size()]),
'weight': ShardSpec([-1], [pg.tp_world_size()]),
},
mode='col',
)

View File

@@ -1,5 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoLinear(ColoModule):
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([-1], [pg.tp_world_size()]),
'weight': ShardSpec([-1], [pg.tp_world_size()]),
'bias': None
},
mode='row',
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([0], [pg.tp_world_size()]),
'bias': distspec.shard([0], [pg.tp_world_size()])
'weight': ShardSpec([0], [pg.tp_world_size()]),
'bias': ShardSpec([0], [pg.tp_world_size()])
},
mode='col',
)