mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[tensor] a shorter shard and replicate spec (#1245)
This commit is contained in:
@@ -7,16 +7,6 @@ class ColoModule(object):
|
||||
|
||||
def __init__(self):
|
||||
self._shard_params: List[str] = []
|
||||
# Example:
|
||||
# {ComputePattern.TP1D:
|
||||
# 'default':
|
||||
# 'weight':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'bias':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'row': ...
|
||||
# 'col': ...
|
||||
# }
|
||||
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
|
||||
|
||||
def _register_shard_params(self, params: List[str]):
|
||||
|
@@ -1,7 +1,5 @@
|
||||
from .colo_module import ColoModule
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
|
||||
|
||||
|
||||
class ColoEmbedding(ColoModule):
|
||||
@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([0], [pg.tp_world_size()]),
|
||||
'weight': ShardSpec([0], [pg.tp_world_size()]),
|
||||
},
|
||||
mode='row',
|
||||
)
|
||||
@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([-1], [pg.tp_world_size()]),
|
||||
'weight': ShardSpec([-1], [pg.tp_world_size()]),
|
||||
},
|
||||
mode='col',
|
||||
)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from .colo_module import ColoModule
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
|
||||
|
||||
|
||||
class ColoLinear(ColoModule):
|
||||
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([-1], [pg.tp_world_size()]),
|
||||
'weight': ShardSpec([-1], [pg.tp_world_size()]),
|
||||
'bias': None
|
||||
},
|
||||
mode='row',
|
||||
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([0], [pg.tp_world_size()]),
|
||||
'bias': distspec.shard([0], [pg.tp_world_size()])
|
||||
'weight': ShardSpec([0], [pg.tp_world_size()]),
|
||||
'bias': ShardSpec([0], [pg.tp_world_size()])
|
||||
},
|
||||
mode='col',
|
||||
)
|
||||
|
Reference in New Issue
Block a user