ColossalAI/colossalai/legacy/nn/parallel/layers/colo_module.py
Hongxin Liu 554aa9592e
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy codes (#4654)

* [legacy] make logger independent to gpc

* [legacy] make optim independent to registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useless rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybrid parallel example

* [devops] test doc check

* [devops] test doc check
2023-09-11 16:24:28 +08:00

48 lines
2.0 KiB
Python

from typing import Dict, List
from colossalai.tensor import ComputePattern
from colossalai.tensor.distspec import _DistSpec
class ColoModule(object):
    """Registry of shardable parameter names and their allowed dist specs.

    The instance keeps two pieces of state:

    * ``_shard_params``: the list of parameter names registered through
      :meth:`_register_shard_params`.
    * ``_allowed_patterns``: a nested mapping
      ``compute_pattern -> mode -> {param_name: dist_spec}`` filled in by
      :meth:`_register_allowed_patterns`.

    :meth:`register` is not implemented here and raises
    ``NotImplementedError``.
    """

    def __init__(self):
        self._shard_params: List[str] = []
        self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}

    def _register_shard_params(self, params: List[str]):
        """Store ``params`` as the list of shardable parameter names.

        The list is stored by reference, not copied.
        """
        self._shard_params = params

    def _register_allowed_patterns(self,
                                   compute_pattern: ComputePattern,
                                   dist_specs: Dict[str, _DistSpec],
                                   mode='default'):
        """Record ``dist_specs`` for ``compute_pattern`` under ``mode``.

        Args:
            compute_pattern: key under which the specs are stored.
            dist_specs: mapping from parameter name to its dist spec. Its keys
                must be exactly the registered shard parameter names.
            mode: name of the variant being registered; an existing entry for
                the same ``(compute_pattern, mode)`` pair is overwritten.

        Raises:
            AssertionError: if the keys of ``dist_specs`` differ from the
                registered shard parameter names.
        """
        # ``list.sort()`` sorts in place and returns None, so comparing the
        # results of two ``.sort()`` calls is always ``None == None``. Use
        # ``sorted`` to compare the actual names, which also leaves the
        # registered parameter list in its original order.
        assert sorted(dist_specs) == sorted(self._shard_params), \
            'Every registered param should have dist_spec.'
        self._allowed_patterns.setdefault(compute_pattern, {})[mode] = dist_specs

    def _set_default(self, compute_pattern: ComputePattern, target_mode):
        """Make ``target_mode`` the ``'default'`` mode of ``compute_pattern``.

        The ``'default'`` entry refers to the same dist-spec mapping object as
        ``target_mode``.

        Raises:
            KeyError: if ``compute_pattern`` or ``target_mode`` has not been
                registered.
        """
        modes = self._allowed_patterns[compute_pattern]
        modes['default'] = modes[target_mode]

    def has_compute_pattern(self, compute_pattern: ComputePattern):
        """Return True if any mode is registered for ``compute_pattern``."""
        return compute_pattern in self._allowed_patterns

    def get_dist_specs(self, compute_pattern: ComputePattern):
        """Return the ``mode -> dist specs`` mapping for ``compute_pattern``.

        Raises:
            AssertionError: if ``compute_pattern`` has not been registered.
        """
        assert self.has_compute_pattern(compute_pattern)
        return self._allowed_patterns[compute_pattern]

    def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
        """Return True if ``mode`` is registered for ``compute_pattern``."""
        return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]

    def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
        """Return the dist specs registered for ``(compute_pattern, mode)``.

        Raises:
            AssertionError: if that pair has not been registered.
        """
        assert self.has_compute_pattern_with_mode(compute_pattern, mode)
        return self._allowed_patterns[compute_pattern][mode]

    def get_param_names(self):
        """Return the registered list of shard parameter names."""
        return self._shard_params

    def register(self, compute_pattern, pg):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError