fixed bug in activation checkpointing test (#387)

This commit is contained in:
Frank Lee
2022-03-11 14:48:11 +08:00
parent 3af13a2c3e
commit 1e4bf85cdb
4 changed files with 25 additions and 14 deletions

View File

@@ -1,9 +1,7 @@
from ._helper import (seed, set_mode, with_seed, add_seed,
get_seeds, get_states, get_current_mode,
set_seed_states, sync_states, moe_set_seed)
from ._helper import (seed, set_mode, with_seed, add_seed, get_seeds, get_states, get_current_mode, set_seed_states,
sync_states, moe_set_seed, reset_seeds)
__all__ = [
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
'moe_set_seed'
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
'sync_states', 'moe_set_seed', 'reset_seeds'
]

View File

@@ -154,4 +154,9 @@ def moe_set_seed(seed):
global_rank = gpc.get_global_rank()
add_seed(ParallelMode.TENSOR, global_rank, True)
print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
f"tensor seed {global_rank}", flush=True)
f"tensor seed {global_rank}",
flush=True)
def reset_seeds():
_SEED_MANAGER.reset()

View File

@@ -66,8 +66,7 @@ class SeedManager:
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
"""
assert isinstance(
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
if overwrtie is False:
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
elif parallel_mode in self._seed_states:
@@ -78,3 +77,8 @@ class SeedManager:
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
self._seeds[parallel_mode] = seed
torch.cuda.set_rng_state(current_state)
def reset(self):
self._current_mode = None
self._seeds = dict()
self._seed_states = dict()