mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
Migrated project
This commit is contained in:
8
colossalai/context/random/__init__.py
Normal file
8
colossalai/context/random/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from ._helper import (seed, set_mode, with_seed, add_seed,
|
||||
get_seeds, get_states, get_current_mode,
|
||||
set_seed_states, sync_states)
|
||||
|
||||
__all__ = [
|
||||
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
|
||||
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states'
|
||||
]
|
||||
144
colossalai/context/random/_helper.py
Normal file
144
colossalai/context/random/_helper.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch.cuda
|
||||
from torch import Tensor
|
||||
|
||||
from .seed_manager import SeedManager
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
_SEED_MANAGER = SeedManager()
|
||||
|
||||
|
||||
def get_seeds():
|
||||
"""Returns the seeds of the seed manager.
|
||||
|
||||
:return: The seeds of the seed manager
|
||||
:rtype: dict
|
||||
"""
|
||||
return _SEED_MANAGER.seeds
|
||||
|
||||
|
||||
def get_states(copy=False):
|
||||
"""Returns the seed states of the seed manager.
|
||||
|
||||
:return: The seed states of the seed manager
|
||||
:rtype: dict
|
||||
"""
|
||||
states = _SEED_MANAGER.seed_states
|
||||
|
||||
if copy:
|
||||
new_states = dict()
|
||||
|
||||
for parallel_mode, state in states.items():
|
||||
new_states[parallel_mode] = state.clone()
|
||||
return new_states
|
||||
else:
|
||||
return _SEED_MANAGER.seed_states
|
||||
|
||||
|
||||
def get_current_mode():
|
||||
"""Returns the current mode of the seed manager.
|
||||
|
||||
:return: The current mode of the seed manager.
|
||||
:rtype: :class:`torch.ByteTensor`
|
||||
"""
|
||||
return _SEED_MANAGER.current_mode
|
||||
|
||||
|
||||
def add_seed(parallel_mode: ParallelMode, seed: int):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param seed: The seed to be added
|
||||
:type seed: int
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
|
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
|
||||
"""
|
||||
_SEED_MANAGER.add_seed(parallel_mode, seed)
|
||||
|
||||
|
||||
def set_mode(parallel_mode: ParallelMode):
|
||||
"""Sets the current mode of the seed manager.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
_SEED_MANAGER.set_mode(parallel_mode)
|
||||
|
||||
|
||||
def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
|
||||
"""Sets the state of the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param state: the state to be set
|
||||
:type state: :class:`torch.Tensor`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
|
||||
"""
|
||||
_SEED_MANAGER.set_state(parallel_mode, state)
|
||||
|
||||
|
||||
def sync_states():
|
||||
current_mode = get_current_mode()
|
||||
current_states = torch.cuda.get_rng_state()
|
||||
set_seed_states(current_mode, current_states)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seed(parallel_mode: ParallelMode):
|
||||
""" A context for seed switch
|
||||
|
||||
Examples::
|
||||
|
||||
with seed(ParallelMode.DATA):
|
||||
output = F.dropout(input)
|
||||
|
||||
"""
|
||||
try:
|
||||
# set to new mode
|
||||
current_mode = _SEED_MANAGER.current_mode
|
||||
yield _SEED_MANAGER.set_mode(parallel_mode)
|
||||
finally:
|
||||
# recover
|
||||
_SEED_MANAGER.set_mode(current_mode)
|
||||
|
||||
|
||||
def with_seed(func, parallel_mode: ParallelMode):
|
||||
"""
|
||||
A function wrapper which executes the function with a specified seed.
|
||||
|
||||
Examples::
|
||||
|
||||
# use with decorator
|
||||
@with_seed(ParallelMode.DATA)
|
||||
def forward(input):
|
||||
return F.dropout(input)
|
||||
out = forward(input)
|
||||
# OR use it inline
|
||||
def forward(input):
|
||||
return F.dropout(input)
|
||||
wrapper_forward = with_seed(forward, ParallelMode.DATA)
|
||||
out = wrapped_forward(input)
|
||||
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# switch mode
|
||||
current_mode = _SEED_MANAGER.current_mode
|
||||
_SEED_MANAGER.set_mode(parallel_mode)
|
||||
|
||||
# exec func
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# recover state
|
||||
_SEED_MANAGER.set_mode(current_mode)
|
||||
|
||||
return out
|
||||
|
||||
return wrapper
|
||||
74
colossalai/context/random/seed_manager.py
Normal file
74
colossalai/context/random/seed_manager.py
Normal file
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
||||
|
||||
class SeedManager:
|
||||
"""This class is a manager of all random seeds involved in the system.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._current_mode = None
|
||||
self._seeds = dict()
|
||||
self._seed_states = dict()
|
||||
|
||||
@property
|
||||
def current_mode(self):
|
||||
return self._current_mode
|
||||
|
||||
@property
|
||||
def seeds(self):
|
||||
return self._seeds
|
||||
|
||||
@property
|
||||
def seed_states(self):
|
||||
return self._seed_states
|
||||
|
||||
def set_state(self, parallel_mode: ParallelMode, state: Tensor):
|
||||
"""Sets the state of the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param state: the state to be set
|
||||
:type state: :class:`torch.Tensor`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
|
||||
"""
|
||||
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
|
||||
self._seed_states[parallel_mode] = state
|
||||
|
||||
def set_mode(self, parallel_mode: ParallelMode):
|
||||
"""Sets the current mode of the seed manager.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
if self.current_mode:
|
||||
# save the current state for current mode
|
||||
self._seed_states[self._current_mode] = torch.cuda.get_rng_state()
|
||||
|
||||
# set the new state for new mode
|
||||
self._current_mode = parallel_mode
|
||||
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
|
||||
|
||||
def add_seed(self, parallel_mode: ParallelMode, seed: int):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param seed: The seed to be added
|
||||
:type seed: int
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
|
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
|
||||
"""
|
||||
assert isinstance(
|
||||
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
|
||||
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
|
||||
current_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.manual_seed(seed)
|
||||
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
|
||||
self._seeds[parallel_mode] = seed
|
||||
torch.cuda.set_rng_state(current_state)
|
||||
Reference in New Issue
Block a user