ColossalAI/colossalai/legacy/nn/parallel/layers/embedding.py
Latest commit 554aa9592e by Hongxin Liu, 2023-09-11 16:24:28 +08:00:
[legacy] move communication and nn to legacy and refactor logger (#4671)

from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec

from .colo_module import ColoModule


class ColoEmbedding(ColoModule):
    """Sharding policy for embedding layers: only the ``weight`` parameter is shardable."""

    def __init__(self):
        super(ColoEmbedding, self).__init__()
        self._register_shard_params(['weight'])

    def register(self, compute_pattern, pg: ProcessGroup):
        # Register the distribution specs for a compute pattern the first time it is requested.
        if compute_pattern not in self._allowed_patterns:
            if ComputePattern.TP1D == compute_pattern:
                self._set_TP1D(pg)

    def _set_TP1D(self, pg: ProcessGroup):
        _compute_pattern = ComputePattern.TP1D
        # TP1D row-sharded embedding: shard the weight along dim 0 (the vocabulary dimension).
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': ShardSpec([0], [pg.tp_world_size()]),
            },
            mode='row',
        )
        # TP1D column-sharded embedding: shard the weight along dim -1 (the embedding dimension).
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': ShardSpec([-1], [pg.tp_world_size()]),
            },
            mode='col',
        )
        # Row sharding is the default mode for TP1D.
        self._set_default(compute_pattern=_compute_pattern, target_mode='row')
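A minimal usage sketch of the policy above, for context. It assumes torch.distributed is already initialized and that ProcessGroup accepts a tp_degree keyword as in other ColoTensor examples; the 2-rank group size is an illustrative assumption, not taken from this file.

# Usage sketch (assumptions: torch.distributed is initialized; ProcessGroup(tp_degree=...)
# is available; adjust to your ColossalAI version).
from colossalai.tensor import ComputePattern, ProcessGroup

from colossalai.legacy.nn.parallel.layers.embedding import ColoEmbedding

pg = ProcessGroup(tp_degree=2)            # hypothetical 1D tensor-parallel group of size 2
policy = ColoEmbedding()
policy.register(ComputePattern.TP1D, pg)  # registers the 'row' and 'col' specs, defaulting to 'row'

# After registration, the policy allows two ways to shard `weight` under TP1D:
#   'row' -> ShardSpec([0],  [2])   # split the vocabulary dimension across ranks (default)
#   'col' -> ShardSpec([-1], [2])   # split the embedding dimension across ranks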