[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -5,7 +5,6 @@ from colossalai.legacy.tensor.distspec import _DistSpec
class ColoModule(object):
def __init__(self):
self._shard_params: List[str] = []
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
@@ -13,18 +12,18 @@ class ColoModule(object):
def _register_shard_params(self, params: List[str]):
self._shard_params = params
def _register_allowed_patterns(self,
compute_pattern: ComputePattern,
dist_specs: Dict[str, _DistSpec],
mode='default'):
assert list(
dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
def _register_allowed_patterns(
self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode="default"
):
assert (
list(dist_specs.keys()).sort() == self._shard_params.sort()
), "Every registered param should have dist_spec."
if not compute_pattern in self._allowed_patterns:
self._allowed_patterns[compute_pattern] = {}
self._allowed_patterns[compute_pattern][mode] = dist_specs
def _set_default(self, compute_pattern: ComputePattern, target_mode):
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode]
self._allowed_patterns[compute_pattern]["default"] = self._allowed_patterns[compute_pattern][target_mode]
def has_compute_pattern(self, compute_pattern: ComputePattern):
return compute_pattern in self._allowed_patterns
@@ -33,10 +32,10 @@ class ColoModule(object):
assert self.has_compute_pattern(compute_pattern)
return self._allowed_patterns[compute_pattern]
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode="default"):
return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode="default"):
assert self.has_compute_pattern_with_mode(compute_pattern, mode)
return self._allowed_patterns[compute_pattern][mode]