mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -13,6 +13,15 @@ from ._helper import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
|
||||
'sync_states', 'moe_set_seed', 'reset_seeds'
|
||||
"seed",
|
||||
"set_mode",
|
||||
"with_seed",
|
||||
"add_seed",
|
||||
"get_seeds",
|
||||
"get_states",
|
||||
"get_current_mode",
|
||||
"set_seed_states",
|
||||
"sync_states",
|
||||
"moe_set_seed",
|
||||
"reset_seeds",
|
||||
]
|
||||
|
@@ -100,7 +100,7 @@ def sync_states():
|
||||
|
||||
@contextmanager
|
||||
def seed(parallel_mode: ParallelMode):
|
||||
""" A context for seed switch
|
||||
"""A context for seed switch
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -162,6 +162,7 @@ def with_seed(func, parallel_mode: ParallelMode):
|
||||
def moe_set_seed(seed):
|
||||
if torch.cuda.is_available():
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
global_rank = gpc.get_global_rank()
|
||||
diff_seed = seed + global_rank
|
||||
add_seed(ParallelMode.TENSOR, diff_seed, True)
|
||||
|
@@ -42,7 +42,7 @@ class SeedManager:
|
||||
Raises:
|
||||
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
|
||||
"""
|
||||
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
|
||||
assert parallel_mode in self._seed_states, f"Parallel mode {parallel_mode} is not found in the seed manager"
|
||||
self._seed_states[parallel_mode] = state
|
||||
|
||||
def set_mode(self, parallel_mode: ParallelMode):
|
||||
@@ -71,9 +71,9 @@ class SeedManager:
|
||||
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode`
|
||||
or the seed for `parallel_mode` has been added.
|
||||
"""
|
||||
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
|
||||
assert isinstance(parallel_mode, ParallelMode), "A valid ParallelMode must be provided"
|
||||
if overwrite is False:
|
||||
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
|
||||
assert parallel_mode not in self._seed_states, f"The seed for {parallel_mode} has been added"
|
||||
elif parallel_mode in self._seed_states:
|
||||
print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)
|
||||
|
||||
|
Reference in New Issue
Block a user