mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[refactory] add nn.parallel module (#1068)
This commit is contained in:
56
colossalai/nn/parallel/layers/colo_module.py
Normal file
56
colossalai/nn/parallel/layers/colo_module.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from colossalai.tensor.distspec import _DistSpec
|
||||
from colossalai.tensor import ComputePattern
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
class ColoModule(object):
|
||||
|
||||
def __init__(self):
|
||||
self._shard_params: List[str] = []
|
||||
# Example:
|
||||
# {ComputePattern.TP1D:
|
||||
# 'default':
|
||||
# 'weight':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'bias':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'row': ...
|
||||
# 'col': ...
|
||||
# }
|
||||
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
|
||||
|
||||
def _register_shard_params(self, params: List[str]):
|
||||
self._shard_params = params
|
||||
|
||||
def _register_allowed_patterns(self,
|
||||
compute_pattern: ComputePattern,
|
||||
dist_specs: Dict[str, _DistSpec],
|
||||
mode='default'):
|
||||
assert list(
|
||||
dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
|
||||
if not compute_pattern in self._allowed_patterns:
|
||||
self._allowed_patterns[compute_pattern] = {}
|
||||
self._allowed_patterns[compute_pattern][mode] = dist_specs
|
||||
|
||||
def _set_default(self, compute_pattern: ComputePattern, target_mode):
|
||||
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode]
|
||||
|
||||
def has_compute_pattern(self, compute_pattern: ComputePattern):
|
||||
return compute_pattern in self._allowed_patterns
|
||||
|
||||
def get_dist_specs(self, compute_pattern: ComputePattern):
|
||||
assert self.has_compute_pattern(compute_pattern)
|
||||
return self._allowed_patterns[compute_pattern]
|
||||
|
||||
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
|
||||
return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]
|
||||
|
||||
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
|
||||
assert self.has_compute_pattern_with_mode(compute_pattern, mode)
|
||||
return self._allowed_patterns[compute_pattern][mode]
|
||||
|
||||
def get_param_names(self):
|
||||
return self._shard_params
|
||||
|
||||
def register(self, compute_pattern):
|
||||
raise NotImplementedError
|
Reference in New Issue
Block a user