mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
Migrated project
This commit is contained in:
5
colossalai/context/__init__.py
Normal file
5
colossalai/context/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import Config
|
||||
from .parallel_context import ParallelContext
|
||||
from .parallel_context import ParallelMode
|
||||
from .process_group_initializer import *
|
||||
from .random import *
|
70
colossalai/context/_utils.py
Normal file
70
colossalai/context/_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import math
|
||||
|
||||
|
||||
def set_parallel_size(obj, config: dict, key: str, attr_name: str):
|
||||
if key in config:
|
||||
ele = config[key]
|
||||
if isinstance(ele, int):
|
||||
setattr(obj, attr_name, ele)
|
||||
elif isinstance(ele, dict):
|
||||
setattr(obj, attr_name, ele['size'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Parallel configuration does not support this kind of argument, please use int or dict"
|
||||
)
|
||||
|
||||
|
||||
def add_tensor_pg(pg_init, mode, size, depth=None):
|
||||
if mode == '1d':
|
||||
pg_init.append(dict(
|
||||
type='Initializer1D',
|
||||
parallel_size=size
|
||||
))
|
||||
elif mode == '2d':
|
||||
dim = math.floor(math.sqrt(size))
|
||||
pg_init.append(dict(
|
||||
type='Initializer2D_Col',
|
||||
summa_dim=dim
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer2D_Row',
|
||||
summa_dim=dim
|
||||
))
|
||||
elif mode == '2.5d':
|
||||
dim = math.floor(math.sqrt(size // depth))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_ROW',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_COL',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_DEP',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_XZ',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
elif mode == '3d':
|
||||
dim = math.floor(math.pow(size, 1.0 / 3.0) + 0.5)
|
||||
pg_init.append(dict(
|
||||
type='ParallelInitializer3D_Input',
|
||||
depth=dim
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='ParallelInitializer3D_Weight',
|
||||
depth=dim
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='ParallelInitializer3D_Output',
|
||||
depth=dim
|
||||
))
|
||||
else:
|
||||
raise NotImplementedError("This kind of tensor splitting has not been implemented yet")
|
99
colossalai/context/config.py
Normal file
99
colossalai/context/config.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Config(dict):
|
||||
"""This is a wrapper class for dict objects so that values of which can be
|
||||
accessed as attributes.
|
||||
|
||||
:param config: The dict object to be wrapped
|
||||
:type config: dict
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
if config is not None:
|
||||
for k, v in config.items():
|
||||
self._add_item(k, v)
|
||||
|
||||
def __missing__(self, key):
|
||||
raise KeyError(key)
|
||||
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
value = super(Config, self).__getitem__(key)
|
||||
return value
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
super(Config, self).__setitem__(key, value)
|
||||
|
||||
def _add_item(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
self.__setattr__(key, Config(value))
|
||||
else:
|
||||
self.__setattr__(key, value)
|
||||
|
||||
def update(self, config):
|
||||
assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
|
||||
for k, v in config.items():
|
||||
self._add_item(k, v)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def from_file(filename: str):
|
||||
"""Reads a python file and constructs a corresponding :class:`Config` object.
|
||||
|
||||
:param filename: Name of the file to construct the return object
|
||||
:type filename: str
|
||||
:raises AssertionError: Raises an AssertionError if the file does not exist, or the file
|
||||
is not .py file
|
||||
:return: A :class:`Config` object constructed with information in the file
|
||||
:rtype: :class:`Config`
|
||||
"""
|
||||
|
||||
# check config path
|
||||
if isinstance(filename, str):
|
||||
filepath = Path(filename).absolute()
|
||||
elif isinstance(filename, Path):
|
||||
filepath = filename.absolute()
|
||||
|
||||
assert filepath.exists(), f'{filename} is not found, please check your configuration path'
|
||||
|
||||
# check extension
|
||||
extension = filepath.suffix
|
||||
assert extension == '.py', 'only .py files are supported'
|
||||
|
||||
# import the config as module
|
||||
remove_path = False
|
||||
if filepath.parent not in sys.path:
|
||||
sys.path.insert(0, (filepath))
|
||||
remove_path = True
|
||||
|
||||
module_name = filepath.stem
|
||||
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
|
||||
module = source_file.load_module()
|
||||
|
||||
# load into config
|
||||
config = Config()
|
||||
|
||||
for k, v in module.__dict__.items():
|
||||
if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
|
||||
continue
|
||||
else:
|
||||
config._add_item(k, v)
|
||||
|
||||
# TODO: replace with logger warning here when logger is done
|
||||
print('warning: variables which starts with __, is a module or class declaration are omitted')
|
||||
|
||||
# remove module
|
||||
del sys.modules[module_name]
|
||||
if remove_path:
|
||||
sys.path.pop(0)
|
||||
|
||||
return config
|
454
colossalai/context/parallel_context.py
Normal file
454
colossalai/context/parallel_context.py
Normal file
@@ -0,0 +1,454 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
|
||||
from colossalai.context.config import Config
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from ._utils import set_parallel_size
|
||||
from .parallel_mode import ParallelMode
|
||||
from .random import add_seed, get_seeds, set_mode
|
||||
|
||||
|
||||
class ParallelContext:
|
||||
"""This class provides interface functions for users to get the parallel context,
|
||||
such as the global rank, the local rank, the world size, etc. of each device.
|
||||
|
||||
:param args: The distributed arguments in the system
|
||||
:type args: dict
|
||||
"""
|
||||
|
||||
def __init__(self, args=None):
|
||||
# distributed settings
|
||||
self._global_ranks = dict()
|
||||
self._local_ranks = dict()
|
||||
self._world_sizes = dict()
|
||||
self._groups = dict()
|
||||
self._ranks_in_group = dict()
|
||||
|
||||
# load config from file
|
||||
self._dist_args = args
|
||||
self._config = None
|
||||
|
||||
# default 3D parallel args, will be overwritten during process group intialization
|
||||
self.world_size = 1
|
||||
self.data_parallel_size = 1
|
||||
self.pipeline_parallel_size = 1
|
||||
self.tensor_parallel_size = 1
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._config
|
||||
|
||||
def load_config(self, config: Union[dict, str]):
|
||||
"""Loads the configuration from either a dict or a file.
|
||||
|
||||
:param config: Either a dict containing the configuration information or the filename
|
||||
of a file containing the configuration information
|
||||
:type config: dict or str
|
||||
:raises TypeError: Raises a TypeError if `config` is neither a dict or a str
|
||||
"""
|
||||
if isinstance(config, str):
|
||||
self._config = Config.from_file(config)
|
||||
elif isinstance(config, dict):
|
||||
self._config = Config(config)
|
||||
else:
|
||||
raise TypeError("Invalid type for config, only dictionary or string is supported")
|
||||
|
||||
def set_dist_args(self, args):
|
||||
"""Sets the distributed arguments.
|
||||
|
||||
:param args: The distributed arguments in the system
|
||||
:type args: dict
|
||||
"""
|
||||
self._dist_args = args
|
||||
|
||||
@staticmethod
|
||||
def _check_parallel_mode(parallel_mode: ParallelMode):
|
||||
assert isinstance(parallel_mode, ParallelMode)
|
||||
|
||||
def get_global_rank(self):
|
||||
"""Returns the global rank of the current device.
|
||||
|
||||
:return: The global rank of the current device
|
||||
:rtype: int
|
||||
"""
|
||||
return self._global_ranks[ParallelMode.GLOBAL]
|
||||
|
||||
def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
|
||||
"""Adds the global rank of the current device for `parallel_mode` to the context.
|
||||
|
||||
:param parallel_mode: The parallel mode for the rank
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param rank: The rank to be added
|
||||
:type rank: int
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
self._global_ranks[parallel_mode] = rank
|
||||
|
||||
def get_local_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns the local rank of the current device.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: The local rank of the current device for `parallel_mode`
|
||||
:rtype: int
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._local_ranks[parallel_mode]
|
||||
|
||||
def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
|
||||
"""Adds the local rank of the current device for `parallel_mode` to the context.
|
||||
|
||||
:param parallel_mode: The parallel mode for the rank
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param rank: The rank to be added
|
||||
:type rank: int
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
self._local_ranks[parallel_mode] = rank
|
||||
|
||||
def get_next_global_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns the global rank of the next device.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: The global rank of the next device for `parallel_mode`
|
||||
:rtype: int
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
|
||||
# get rank and world size
|
||||
local_rank = self.get_local_rank(parallel_mode)
|
||||
world_size = self.get_world_size(parallel_mode)
|
||||
ranks_in_group = self.get_ranks_in_group(parallel_mode)
|
||||
|
||||
return ranks_in_group[(local_rank + 1) % world_size]
|
||||
|
||||
def get_prev_global_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns the global rank of the previous device.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: The global rank of the previous device for `parallel_mode`
|
||||
:rtype: int
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
|
||||
# get rank and world size
|
||||
local_rank = self.get_local_rank(parallel_mode)
|
||||
world_size = self.get_world_size(parallel_mode)
|
||||
ranks_in_group = self.get_ranks_in_group(parallel_mode)
|
||||
|
||||
return ranks_in_group[(local_rank - 1) % world_size]
|
||||
|
||||
def is_first_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns a boolean value indicating whether the current device is the first one
|
||||
among its group for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: a boolean value indicating whether the current device is the first one
|
||||
among its group for `parallel_mode`
|
||||
:rtype: bool
|
||||
"""
|
||||
rank = self.get_local_rank(parallel_mode)
|
||||
return rank == 0
|
||||
|
||||
def is_last_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns a boolean value indicating whether the current device is the last one
|
||||
among its group for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: a boolean value indicating whether the current device is the last one
|
||||
among its group for `parallel_mode`
|
||||
:rtype: bool
|
||||
"""
|
||||
rank = self.get_local_rank(parallel_mode)
|
||||
world_size = self.get_world_size(parallel_mode)
|
||||
return rank == world_size - 1
|
||||
|
||||
def get_world_size(self, parallel_mode: ParallelMode):
|
||||
"""Returns the world size for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: The world size for `parallel_mode`
|
||||
:rtype: int
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._world_sizes[parallel_mode]
|
||||
|
||||
def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
|
||||
"""Adds world size for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param world_size: The world size to be added
|
||||
:type world_size: int
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
self._world_sizes[parallel_mode] = world_size
|
||||
|
||||
def get_group(self, parallel_mode: ParallelMode):
|
||||
"""Returns the group of the current device for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: The group of the current device for `parallel_mode`
|
||||
:rtype: torch.distributed.ProcessGroup
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._groups[parallel_mode]
|
||||
|
||||
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
|
||||
"""Adds the group of the current device for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param group: The group to be added
|
||||
:type group: torch.distributed.ProcessGroup
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
self._groups[parallel_mode] = group
|
||||
|
||||
def get_ranks_in_group(self, parallel_mode: ParallelMode):
|
||||
"""Returns the rank of the current device for `parallel_mode` in the group.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
:return: the rank of the current device for `parallel_mode` in the group
|
||||
:rtype: int
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._ranks_in_group[parallel_mode]
|
||||
|
||||
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
|
||||
"""Adds the ranks of the current device for `parallel_mode` in the group.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param ranks: List of ranks to be added
|
||||
:type ranks: list
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
|
||||
of :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
self._ranks_in_group[parallel_mode] = ranks
|
||||
|
||||
def init_global_dist(self, addr=None, port=None):
|
||||
"""Initializes the global distributed environment.
|
||||
|
||||
:param addr: The IP address of the current device
|
||||
:type addr: str, optional
|
||||
:param port: The port to be used in the system of the current device
|
||||
:type port: int, optional
|
||||
"""
|
||||
# get config
|
||||
rank = self._dist_args.local_rank
|
||||
world_size = self._dist_args.world_size
|
||||
# default env config, overwrite by exporting
|
||||
# them in your bash script
|
||||
addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
|
||||
port = os.getenv('MASTER_PORT', '8008') if port is None else port
|
||||
init_method = f'tcp://{addr}:{port}'
|
||||
|
||||
dist.init_process_group(backend=self._dist_args.backend,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
init_method=init_method)
|
||||
|
||||
# None will give the default global process group for pytorch dist operations
|
||||
self._register_dist(rank, world_size, None,
|
||||
list(range(world_size)), ParallelMode.GLOBAL)
|
||||
self._global_ranks[ParallelMode.GLOBAL] = rank
|
||||
|
||||
def _register_dist(self, local_rank, world_size,
|
||||
process_group, ranks_in_group, mode):
|
||||
self.add_local_rank(mode, local_rank)
|
||||
self.add_world_size(mode, world_size)
|
||||
self.add_group(mode, process_group)
|
||||
self.add_ranks_in_group(mode, ranks_in_group)
|
||||
|
||||
def check_sanity(self):
|
||||
"""Checks sanity of the parallel context.
|
||||
|
||||
:raises AssertionError: Raises an AssertionError if the world size does not equal to the product
|
||||
of data paralle size, pipeline parallel size and tensor parallel size
|
||||
"""
|
||||
dps = self.data_parallel_size
|
||||
pps = self.pipeline_parallel_size
|
||||
tps = self.tensor_parallel_size
|
||||
ws = self.world_size
|
||||
assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
|
||||
|
||||
def init_parallel_groups(self):
|
||||
"""Initializes the parallel groups.
|
||||
|
||||
:raises AssertionError: Raises an AssertionError if the field paralle is not present in the config file
|
||||
"""
|
||||
|
||||
# get rank and world size
|
||||
rank = self.get_global_rank()
|
||||
world_size = self.get_world_size(ParallelMode.GLOBAL)
|
||||
self.world_size = world_size
|
||||
|
||||
assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file'
|
||||
|
||||
# set parallel size as attributes for global context
|
||||
parallel_config = self.config.parallel
|
||||
set_parallel_size(self, parallel_config, 'pipeline',
|
||||
'pipeline_parallel_size')
|
||||
set_parallel_size(self, parallel_config, 'tensor',
|
||||
'tensor_parallel_size')
|
||||
|
||||
# the user should not set the data parallel size manually
|
||||
# instead, it should be calculated based on other parallel config
|
||||
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
|
||||
|
||||
# get the tensor parallel mode and check
|
||||
tensor_parallel_mode = parallel_config['tensor'].get('mode', None)
|
||||
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
|
||||
self.check_sanity()
|
||||
|
||||
pg_init = []
|
||||
# LSG: init data parallel process group for compatibility with other parallel module such as zero
|
||||
pg_init.append(dict(type=INITIALIZER_MAPPING['data']))
|
||||
|
||||
if self.pipeline_parallel_size > 1:
|
||||
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
|
||||
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))
|
||||
|
||||
# init specific tensor parallel group
|
||||
if tensor_parallel_mode is not None:
|
||||
tensor_parallel_cfg = parallel_config['tensor'].copy()
|
||||
|
||||
# remove duplicate parameters
|
||||
tensor_parallel_cfg.pop('mode')
|
||||
tensor_parallel_cfg.pop('size')
|
||||
|
||||
# add this config to initialize later
|
||||
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))
|
||||
|
||||
# run initialization of different process groups
|
||||
for initializer_cfg in pg_init:
|
||||
cfg = initializer_cfg.copy()
|
||||
initializer_type = cfg.pop('type')
|
||||
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
|
||||
rank, world_size, self.config,
|
||||
self.data_parallel_size,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
**cfg)
|
||||
parallel_setting = initializer.init_dist_group()
|
||||
if isinstance(parallel_setting, list):
|
||||
for args in parallel_setting:
|
||||
self._register_dist(*args)
|
||||
else:
|
||||
self._register_dist(*parallel_setting)
|
||||
|
||||
def is_initialized(self, parallel_mode: ParallelMode):
|
||||
"""Returns a boolean value indicating whether `parallel_mode` is initialized
|
||||
in the current system.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:return: a boolean value indicating whether `parallel_mode` is initialized
|
||||
in the current system
|
||||
:rtype: bool
|
||||
"""
|
||||
return parallel_mode in self._groups
|
||||
|
||||
def destroy(self):
|
||||
"""Destroys the current distributed parallel environment.
|
||||
"""
|
||||
for mode, group in self._groups.items():
|
||||
if mode is not ParallelMode.GLOBAL:
|
||||
dist.destroy_process_group(group)
|
||||
# destroy global process group
|
||||
dist.destroy_process_group()
|
||||
|
||||
def set_device(self):
|
||||
"""Sets distributed processes to be bound to devices.
|
||||
"""
|
||||
devices_per_node = torch.cuda.device_count()
|
||||
global_rank = self.get_global_rank()
|
||||
device = global_rank % devices_per_node
|
||||
torch.cuda.set_device(device)
|
||||
print(f'process rank {global_rank} is bound to device {device}')
|
||||
|
||||
def set_seed(self):
|
||||
"""Sets seeds for all random libraries.
|
||||
"""
|
||||
if hasattr(self.config, 'seed'):
|
||||
seed = getattr(self.config, 'seed')
|
||||
else:
|
||||
seed = 2 # default seed
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
global_rank = self.get_global_rank()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# create random seed for different parallel modes
|
||||
# data parallel seed are kept the same
|
||||
parallel_seed = seed
|
||||
add_seed(ParallelMode.DATA, parallel_seed)
|
||||
|
||||
# model parallel seeds are different across ranks
|
||||
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
|
||||
|
||||
# add seed for data parallel and tensor parallel only
|
||||
if self.is_initialized(ParallelMode.TENSOR):
|
||||
tp_rank = self.get_local_rank(ParallelMode.TENSOR)
|
||||
# 100 is only to increase the diff in seeds between pipeline stages
|
||||
tp_rank_with_offset = tp_rank + pipeline_offset * 1024
|
||||
tp_seed = seed + tp_rank_with_offset
|
||||
add_seed(ParallelMode.TENSOR, tp_seed)
|
||||
|
||||
set_mode(ParallelMode.DATA)
|
||||
seeds = get_seeds()
|
||||
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
|
||||
|
||||
print(f"initialized seed on rank {global_rank}, "
|
||||
f"numpy: {seed}, python random: {seed}, {seed_str},"
|
||||
f"the default parallel seed is {ParallelMode.DATA}.", flush=True)
|
||||
else:
|
||||
print(f"initialized seed on rank {global_rank}, "
|
||||
f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True)
|
||||
print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
|
||||
flush=True)
|
44
colossalai/context/parallel_mode.py
Normal file
44
colossalai/context/parallel_mode.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# parallel modes
|
||||
class ParallelMode(Enum):
|
||||
"""This is an enumeration class containing all possible parallel modes.
|
||||
"""
|
||||
|
||||
GLOBAL = 'global'
|
||||
|
||||
# common parallel
|
||||
DATA = 'data'
|
||||
|
||||
# pipeline parallel
|
||||
PIPELINE = 'pipe'
|
||||
PIPELINE_PREV = 'pipe_prev'
|
||||
PIPELINE_NEXT = 'pipe_next'
|
||||
|
||||
# containing all ranks in tensor parallel
|
||||
TENSOR = 'tensor'
|
||||
|
||||
# sequence parallel
|
||||
SEQUENCE = 'sequence'
|
||||
|
||||
# 1D Parallel
|
||||
PARALLEL_1D = '1d'
|
||||
|
||||
# 2D parallel
|
||||
PARALLEL_2D_ROW = '2d_row'
|
||||
PARALLEL_2D_COL = '2d_col'
|
||||
|
||||
# 3D parallel
|
||||
PARALLEL_3D_INPUT = '3d_input'
|
||||
PARALLEL_3D_WEIGHT = '3d_weight'
|
||||
PARALLEL_3D_OUTPUT = '3d_output'
|
||||
|
||||
# 2.5D parallel
|
||||
PARALLEL_2P5D_ROW = '2p5d_row'
|
||||
PARALLEL_2P5D_COL = '2p5d_col'
|
||||
PARALLEL_2P5D_DEP = '2p5d_dep'
|
||||
PARALLEL_2P5D_XZ = '2p5d_xz'
|
15
colossalai/context/process_group_initializer/__init__.py
Normal file
15
colossalai/context/process_group_initializer/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .initializer_1d import Initializer_1D
|
||||
from .initializer_2d import Initializer_2D
|
||||
from .initializer_2p5d import Initializer_2p5D
|
||||
from .initializer_3d import Initializer_3D
|
||||
from .initializer_data import Initializer_Data
|
||||
from .initializer_pipeline import Initializer_Pipeline
|
||||
from .initializer_sequence import Initializer_Sequence
|
||||
from .initializer_tensor import Initializer_Tensor
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
|
||||
__all__ = [
|
||||
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
|
||||
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
|
||||
'Initializer_1D', 'ProcessGroupInitializer'
|
||||
]
|
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context import Config
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_1D(ProcessGroupInitializer):
|
||||
'''A ProcessGroupInitializer for 1d tensor parallelism.
|
||||
'''
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
:rtype: tuple
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_1D
|
||||
|
||||
for i in range(self.num_group):
|
||||
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
123
colossalai/context/process_group_initializer/initializer_2d.py
Normal file
123
colossalai/context/process_group_initializer/initializer_2d.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.constants import SUMMA_DIM
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
def _check_summa_env_var(summa_dim):
|
||||
# check environment variable for SUMMA
|
||||
env_summa_dim = os.environ.get(SUMMA_DIM, None)
|
||||
|
||||
if env_summa_dim:
|
||||
assert int(env_summa_dim) == summa_dim, \
|
||||
'SUMMA_DIM has been set in the current environment and ' \
|
||||
'does not match with the value passed to this initialized'
|
||||
else:
|
||||
os.environ[SUMMA_DIM] = str(summa_dim)
|
||||
|
||||
|
||||
class Initializer_2D_Row(ProcessGroupInitializer):
|
||||
'''2d tensor parallel initialization among rows.
|
||||
'''
|
||||
|
||||
def __init__(self, num_group, summa_dim, *args, **kwargs):
|
||||
super(Initializer_2D_Row, self).__init__(*args, **kwargs)
|
||||
self.num_group = num_group
|
||||
self.summa_dim = summa_dim
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 2D tensor row parallelism's information
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_2D_ROW
|
||||
|
||||
for i in range(self.num_group):
|
||||
for j in range(self.summa_dim):
|
||||
ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k
|
||||
for k in range(self.summa_dim)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_2D_Col(ProcessGroupInitializer):
|
||||
'''2d tensor parallel initialization among cols.
|
||||
'''
|
||||
|
||||
def __init__(self, num_group, summa_dim, *args, **kwargs):
|
||||
super(Initializer_2D_Col, self).__init__(*args, **kwargs)
|
||||
self.num_group = num_group
|
||||
self.summa_dim = summa_dim
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 2D tensor col parallelism's information
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_2D_COL
|
||||
|
||||
for i in range(self.num_group):
|
||||
for j in range(self.summa_dim):
|
||||
ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim
|
||||
for k in range(self.summa_dim)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_2D(ProcessGroupInitializer):
|
||||
"""
|
||||
Serve as the single entry point to 2D parallel initialization.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
|
||||
|
||||
assert self.tensor_parallel_size == self.summa_dim ** 2, \
|
||||
"2D summa dim should equal to tensor parallel size ^ 0.5"
|
||||
_check_summa_env_var(self.summa_dim)
|
||||
|
||||
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
|
||||
self.row_initializer = Initializer_2D_Row(self.num_group, self.summa_dim, *args, **kwargs)
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 2D tensor parallelism's information
|
||||
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
parallel_setting = []
|
||||
parallel_setting.append(self.row_initializer.init_dist_group())
|
||||
parallel_setting.append(self.col_initializer.init_dist_group())
|
||||
return parallel_setting
|
255
colossalai/context/process_group_initializer/initializer_2p5d.py
Normal file
255
colossalai/context/process_group_initializer/initializer_2p5d.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
|
||||
from colossalai.context import Config
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
def _check_tesseract_env_var(tesseract_dim: int,
|
||||
tesseract_dep: int):
|
||||
# check environment variable for TESSERACT
|
||||
env_tesseract_dim = os.environ.get(TESSERACT_DIM, None)
|
||||
env_tesseract_dep = os.environ.get(TESSERACT_DEP, None)
|
||||
|
||||
if env_tesseract_dim and env_tesseract_dep:
|
||||
assert int(env_tesseract_dim) == tesseract_dim, \
|
||||
'TESSERACT_DIM has been set in the current environment and ' \
|
||||
'does not match with the value passed to this initialized'
|
||||
assert int(env_tesseract_dep) == tesseract_dep, \
|
||||
'TESSERACT_DEP has been set in the current environment and ' \
|
||||
'does not match with the value passed to this initialized'
|
||||
else:
|
||||
os.environ[TESSERACT_DIM] = str(tesseract_dim)
|
||||
os.environ[TESSERACT_DEP] = str(tesseract_dep)
|
||||
|
||||
|
||||
# i row j col k dep
|
||||
class Initializer_2p5D_ROW(ProcessGroupInitializer):
|
||||
'''2p5d tensor parallel initialization among rows.
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
*args):
|
||||
super(Initializer_2p5D_ROW, self).__init__(*args)
|
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.tesseract_dep = tesseract_dep
|
||||
self.tesseract_dim = tesseract_dim
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2p5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 2p5D tensor row parallelism's information
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_2P5D_ROW
|
||||
|
||||
for h in range(self.num_group):
|
||||
for j in range(self.tesseract_dim):
|
||||
for k in range(self.tesseract_dep):
|
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
|
||||
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_2p5D_Col(ProcessGroupInitializer):
|
||||
'''2p5d tensor parallel initialization among cols.
|
||||
'''
|
||||
def __init__(self,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
*args):
|
||||
super(Initializer_2p5D_Col, self).__init__(*args)
|
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.tesseract_dep = tesseract_dep
|
||||
self.tesseract_dim = tesseract_dim
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2p5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 2p5D tensor col parallelism's information
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_2P5D_COL
|
||||
|
||||
for h in range(self.num_group):
|
||||
for i in range(self.tesseract_dim):
|
||||
for k in range(self.tesseract_dep):
|
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
|
||||
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_2p5D_Dep(ProcessGroupInitializer):
|
||||
'''2p5D tensor parallel initialization among depths.
|
||||
'''
|
||||
def __init__(self,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
*args):
|
||||
super(Initializer_2p5D_Dep, self).__init__(*args)
|
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.tesseract_dep = tesseract_dep
|
||||
self.tesseract_dim = tesseract_dim
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2p5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 2p5D tensor depth parallelism's information
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_2P5D_DEP
|
||||
|
||||
for h in range(self.num_group):
|
||||
for i in range(self.tesseract_dim):
|
||||
for j in range(self.tesseract_dim):
|
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
|
||||
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
# i row j col k dep
|
||||
class Initializer_2p5D_XZ(ProcessGroupInitializer):
|
||||
'''2p5d tensor parallel initialization among cols times dep.
|
||||
'''
|
||||
def __init__(self,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
*args):
|
||||
super(Initializer_2p5D_XZ, self).__init__(*args)
|
||||
|
||||
self.tensor_parallel_size = gpc.tensor_parallel_size
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.tesseract_dep = tesseract_dep
|
||||
self.tesseract_dim = tesseract_dim
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2p5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 2p5D tensor colXdepth parallelism's information
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_2P5D_XZ
|
||||
|
||||
for h in range(self.num_group):
|
||||
for i in range(self.tesseract_dim):
|
||||
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
|
||||
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
|
||||
range(self.tesseract_dim)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_2p5D(ProcessGroupInitializer):
|
||||
"""
|
||||
Serve as the single entry point to Tesseract parallel initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
config: Config,
|
||||
data_parallel_size: int,
|
||||
pipeline_parlalel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
depth: int
|
||||
):
|
||||
args = (rank, world_size, config, data_parallel_size, pipeline_parlalel_size, tensor_parallel_size)
|
||||
super().__init__(*args)
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth))
|
||||
self.tesseract_dep = depth
|
||||
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
||||
"2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5"
|
||||
_check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep)
|
||||
|
||||
self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args)
|
||||
self.row_initializer = Initializer_2p5D_ROW(self.tesseract_dim, self.tesseract_dep, *args)
|
||||
self.dep_initializer = Initializer_2p5D_Dep(self.tesseract_dim, self.tesseract_dep, *args)
|
||||
self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args)
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: Whole 2p5D tensor parallelism's information
|
||||
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
parallel_setting = []
|
||||
parallel_setting.append(self.col_initializer.init_dist_group())
|
||||
parallel_setting.append(self.row_initializer.init_dist_group())
|
||||
parallel_setting.append(self.dep_initializer.init_dist_group())
|
||||
parallel_setting.append(self.xz_initializer.init_dist_group())
|
||||
return parallel_setting
|
172
colossalai/context/process_group_initializer/initializer_3d.py
Normal file
172
colossalai/context/process_group_initializer/initializer_3d.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
from colossalai.constants import DEPTH_3D
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
|
||||
from ..parallel_mode import ParallelMode
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
|
||||
|
||||
def _check_depth_env_var(depth):
|
||||
# check environment variable for SUMMA
|
||||
env_depth = os.environ.get(DEPTH_3D, None)
|
||||
|
||||
if env_depth:
|
||||
assert int(env_depth) == depth, \
|
||||
'SUMMA_DIM has been set in the current environment and ' \
|
||||
'does not match with the value passed to this initialized'
|
||||
else:
|
||||
os.environ[DEPTH_3D] = str(depth)
|
||||
|
||||
|
||||
class Initializer_3D_Input(ProcessGroupInitializer):
|
||||
'''2D tensor parallel initialization among input.
|
||||
'''
|
||||
def __init__(self, num_group: int, depth: int, *args):
|
||||
super().__init__(*args)
|
||||
self.num_group = num_group
|
||||
self.depth = depth
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 3D tensor parallelism's information among input
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_3D_INPUT
|
||||
|
||||
for h in range(self.num_group):
|
||||
for i in range(self.depth):
|
||||
for k in range(self.depth):
|
||||
ranks = [
|
||||
h * self.depth**3 + i + self.depth *
|
||||
(j + self.depth * k) for j in range(self.depth)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_3D_Weight(ProcessGroupInitializer):
|
||||
'''3D tensor parallel initialization among weight.
|
||||
'''
|
||||
|
||||
def __init__(self, num_group: int, depth: int, *args):
|
||||
super().__init__(*args)
|
||||
self.num_group = num_group
|
||||
self.depth = depth
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 3D tensor parallelism's information among weight
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_3D_WEIGHT
|
||||
|
||||
for h in range(self.num_group):
|
||||
for k in range(self.depth):
|
||||
for j in range(self.depth):
|
||||
ranks = [
|
||||
h * self.depth**3 + i + self.depth *
|
||||
(j + self.depth * k) for i in range(self.depth)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_3D_Output(ProcessGroupInitializer):
|
||||
'''2D tensor parallel initialization among weight.
|
||||
'''
|
||||
|
||||
def __init__(self, num_group: int, depth: int, *args):
|
||||
super().__init__(*args)
|
||||
self.num_group = num_group
|
||||
self.depth = depth
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 3D tensor parallelism's information among output
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_3D_OUTPUT
|
||||
|
||||
for h in range(self.num_group):
|
||||
for i in range(self.depth):
|
||||
for j in range(self.depth):
|
||||
ranks = [
|
||||
h * self.depth**3 + i + self.depth *
|
||||
(j + self.depth * k) for k in range(self.depth)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_3D(ProcessGroupInitializer):
|
||||
'''Serve as the single entry point to 3D parallel initialization.
|
||||
'''
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3))
|
||||
assert self.tensor_parallel_size == self.depth ** 3, \
|
||||
f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})'
|
||||
_check_depth_env_var(self.depth)
|
||||
|
||||
self.input_initializer = Initializer_3D_Input(self.num_group,
|
||||
self.depth, *args)
|
||||
self.weight_initializer = Initializer_3D_Weight(
|
||||
self.num_group, self.depth, *args)
|
||||
self.output_initializer = Initializer_3D_Output(
|
||||
self.num_group, self.depth, *args)
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: 3D tensor parallelism's information
|
||||
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
parallel_setting = []
|
||||
parallel_setting.append(self.input_initializer.init_dist_group())
|
||||
parallel_setting.append(self.weight_initializer.init_dist_group())
|
||||
parallel_setting.append(self.output_initializer.init_dist_group())
|
||||
return parallel_setting
|
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from torch import distributed as dist
|
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_Data(ProcessGroupInitializer):
|
||||
'''A ProcessGroupInitializer for data parallelism.
|
||||
'''
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_data_parallel_group = self.world_size // self.data_parallel_size
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize data parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: data parallelism's information
|
||||
:rtype: tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.DATA
|
||||
|
||||
for i in range(self.num_data_parallel_group):
|
||||
ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from torch import distributed as dist
|
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_Pipeline(ProcessGroupInitializer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.data_group_size = self.world_size // self.data_parallel_size
|
||||
self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size
|
||||
|
||||
def init_dist_group(self):
|
||||
dist_settings = list()
|
||||
for i in range(self.data_parallel_size):
|
||||
for j in range(self.pipeline_stage_size):
|
||||
pipe_ranks = list(
|
||||
range(i * self.data_group_size + j,
|
||||
(i + 1) * self.data_group_size,
|
||||
self.pipeline_stage_size))
|
||||
pipe_group_size = len(pipe_ranks)
|
||||
pipe_group = dist.new_group(pipe_ranks)
|
||||
|
||||
if self.rank in pipe_ranks:
|
||||
local_rank = pipe_ranks.index(self.rank)
|
||||
group_world_size = pipe_group_size
|
||||
process_group = pipe_group
|
||||
ranks_in_group = pipe_ranks
|
||||
dist_settings.append(
|
||||
tuple((local_rank, group_world_size,
|
||||
process_group, ranks_in_group,
|
||||
ParallelMode.PIPELINE)))
|
||||
|
||||
for k in range(pipe_group_size):
|
||||
first = pipe_ranks[k]
|
||||
second = pipe_ranks[(k + 1) % pipe_group_size]
|
||||
ranks = [first, second]
|
||||
group = dist.new_group(ranks)
|
||||
if self.rank == first:
|
||||
local_rank = 0
|
||||
group_world_size = 2
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
dist_settings.append(
|
||||
tuple((local_rank, group_world_size,
|
||||
process_group, ranks_in_group,
|
||||
ParallelMode.PIPELINE_NEXT)))
|
||||
elif self.rank == second:
|
||||
local_rank = 1
|
||||
group_world_size = 2
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
dist_settings.append(
|
||||
tuple((local_rank, group_world_size,
|
||||
process_group, ranks_in_group,
|
||||
ParallelMode.PIPELINE_PREV)))
|
||||
|
||||
return dist_settings
|
@@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .initializer_tensor import Initializer_Tensor
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_Sequence(ProcessGroupInitializer):
|
||||
'''A ProcessGroupInitializer for sequence parallelism.
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# reuse tensor parallel code
|
||||
self._initializer = Initializer_Tensor(*args, **kwargs)
|
||||
|
||||
def init_dist_group(self):
|
||||
local_rank, group_world_size, process_group, ranks_in_group, mode = self._initializer.init_dist_group()
|
||||
|
||||
# change mode to sequence
|
||||
mode = ParallelMode.SEQUENCE
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
class Initializer_Tensor(ProcessGroupInitializer):
|
||||
'''A ProcessGroupInitializer for tensor parallelism.
|
||||
'''
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size
|
||||
|
||||
def init_dist_group(self):
|
||||
'''Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
:return: tensor parallelism's information
|
||||
:rtype: tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||
'''
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.TENSOR
|
||||
|
||||
for i in range(self.num_tensor_parallel_group):
|
||||
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from colossalai.context import Config
|
||||
|
||||
|
||||
class ProcessGroupInitializer(ABC):
|
||||
'''An object, knowing the parallelism configuration, that initializes parallel groups.
|
||||
'''
|
||||
def __init__(self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
config: Config,
|
||||
data_parallel_size: int,
|
||||
pipeline_parlalel_size: int,
|
||||
tensor_parallel_size: int
|
||||
):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.config = config
|
||||
self.pipeline_parallel_size = pipeline_parlalel_size
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def init_dist_group(self):
|
||||
pass
|
8
colossalai/context/random/__init__.py
Normal file
8
colossalai/context/random/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from ._helper import (seed, set_mode, with_seed, add_seed,
|
||||
get_seeds, get_states, get_current_mode,
|
||||
set_seed_states, sync_states)
|
||||
|
||||
__all__ = [
|
||||
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
|
||||
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states'
|
||||
]
|
144
colossalai/context/random/_helper.py
Normal file
144
colossalai/context/random/_helper.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch.cuda
|
||||
from torch import Tensor
|
||||
|
||||
from .seed_manager import SeedManager
|
||||
from ..parallel_mode import ParallelMode
|
||||
|
||||
_SEED_MANAGER = SeedManager()
|
||||
|
||||
|
||||
def get_seeds():
|
||||
"""Returns the seeds of the seed manager.
|
||||
|
||||
:return: The seeds of the seed manager
|
||||
:rtype: dict
|
||||
"""
|
||||
return _SEED_MANAGER.seeds
|
||||
|
||||
|
||||
def get_states(copy=False):
|
||||
"""Returns the seed states of the seed manager.
|
||||
|
||||
:return: The seed states of the seed manager
|
||||
:rtype: dict
|
||||
"""
|
||||
states = _SEED_MANAGER.seed_states
|
||||
|
||||
if copy:
|
||||
new_states = dict()
|
||||
|
||||
for parallel_mode, state in states.items():
|
||||
new_states[parallel_mode] = state.clone()
|
||||
return new_states
|
||||
else:
|
||||
return _SEED_MANAGER.seed_states
|
||||
|
||||
|
||||
def get_current_mode():
|
||||
"""Returns the current mode of the seed manager.
|
||||
|
||||
:return: The current mode of the seed manager.
|
||||
:rtype: :class:`torch.ByteTensor`
|
||||
"""
|
||||
return _SEED_MANAGER.current_mode
|
||||
|
||||
|
||||
def add_seed(parallel_mode: ParallelMode, seed: int):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param seed: The seed to be added
|
||||
:type seed: int
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
|
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
|
||||
"""
|
||||
_SEED_MANAGER.add_seed(parallel_mode, seed)
|
||||
|
||||
|
||||
def set_mode(parallel_mode: ParallelMode):
|
||||
"""Sets the current mode of the seed manager.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
_SEED_MANAGER.set_mode(parallel_mode)
|
||||
|
||||
|
||||
def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
|
||||
"""Sets the state of the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param state: the state to be set
|
||||
:type state: :class:`torch.Tensor`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
|
||||
"""
|
||||
_SEED_MANAGER.set_state(parallel_mode, state)
|
||||
|
||||
|
||||
def sync_states():
|
||||
current_mode = get_current_mode()
|
||||
current_states = torch.cuda.get_rng_state()
|
||||
set_seed_states(current_mode, current_states)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seed(parallel_mode: ParallelMode):
|
||||
""" A context for seed switch
|
||||
|
||||
Examples::
|
||||
|
||||
with seed(ParallelMode.DATA):
|
||||
output = F.dropout(input)
|
||||
|
||||
"""
|
||||
try:
|
||||
# set to new mode
|
||||
current_mode = _SEED_MANAGER.current_mode
|
||||
yield _SEED_MANAGER.set_mode(parallel_mode)
|
||||
finally:
|
||||
# recover
|
||||
_SEED_MANAGER.set_mode(current_mode)
|
||||
|
||||
|
||||
def with_seed(func, parallel_mode: ParallelMode):
|
||||
"""
|
||||
A function wrapper which executes the function with a specified seed.
|
||||
|
||||
Examples::
|
||||
|
||||
# use with decorator
|
||||
@with_seed(ParallelMode.DATA)
|
||||
def forward(input):
|
||||
return F.dropout(input)
|
||||
out = forward(input)
|
||||
# OR use it inline
|
||||
def forward(input):
|
||||
return F.dropout(input)
|
||||
wrapper_forward = with_seed(forward, ParallelMode.DATA)
|
||||
out = wrapped_forward(input)
|
||||
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# switch mode
|
||||
current_mode = _SEED_MANAGER.current_mode
|
||||
_SEED_MANAGER.set_mode(parallel_mode)
|
||||
|
||||
# exec func
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# recover state
|
||||
_SEED_MANAGER.set_mode(current_mode)
|
||||
|
||||
return out
|
||||
|
||||
return wrapper
|
74
colossalai/context/random/seed_manager.py
Normal file
74
colossalai/context/random/seed_manager.py
Normal file
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
||||
|
||||
class SeedManager:
|
||||
"""This class is a manager of all random seeds involved in the system.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._current_mode = None
|
||||
self._seeds = dict()
|
||||
self._seed_states = dict()
|
||||
|
||||
@property
|
||||
def current_mode(self):
|
||||
return self._current_mode
|
||||
|
||||
@property
|
||||
def seeds(self):
|
||||
return self._seeds
|
||||
|
||||
@property
|
||||
def seed_states(self):
|
||||
return self._seed_states
|
||||
|
||||
def set_state(self, parallel_mode: ParallelMode, state: Tensor):
|
||||
"""Sets the state of the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param state: the state to be set
|
||||
:type state: :class:`torch.Tensor`
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
|
||||
"""
|
||||
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
|
||||
self._seed_states[parallel_mode] = state
|
||||
|
||||
def set_mode(self, parallel_mode: ParallelMode):
|
||||
"""Sets the current mode of the seed manager.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
"""
|
||||
if self.current_mode:
|
||||
# save the current state for current mode
|
||||
self._seed_states[self._current_mode] = torch.cuda.get_rng_state()
|
||||
|
||||
# set the new state for new mode
|
||||
self._current_mode = parallel_mode
|
||||
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
|
||||
|
||||
def add_seed(self, parallel_mode: ParallelMode, seed: int):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`.
|
||||
|
||||
:param parallel_mode: The chosen parallel mode
|
||||
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
||||
:param seed: The seed to be added
|
||||
:type seed: int
|
||||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
|
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
|
||||
"""
|
||||
assert isinstance(
|
||||
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
|
||||
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
|
||||
current_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.manual_seed(seed)
|
||||
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
|
||||
self._seeds[parallel_mode] = seed
|
||||
torch.cuda.set_rng_state(current_state)
|
Reference in New Issue
Block a user