Refactored docstrings to Google style

Liang Bowen
2022-03-25 13:02:39 +08:00
committed by アマデウス
parent 53b1b6e340
commit ec5086c49c
94 changed files with 3389 additions and 2982 deletions

View File

@@ -12,8 +12,8 @@ class Config(dict):
"""This is a wrapper class for dict objects so that values of which can be
accessed as attributes.
:param config: The dict object to be wrapped
:type config: dict
Args:
config (dict): The dict object to be wrapped.
"""
def __init__(self, config: dict = None):
@@ -50,12 +50,14 @@ class Config(dict):
def from_file(filename: str):
"""Reads a python file and constructs a corresponding :class:`Config` object.
:param filename: Name of the file to construct the return object
:type filename: str
:raises AssertionError: Raises an AssertionError if the file does not exist, or the file
is not .py file
:return: A :class:`Config` object constructed with information in the file
:rtype: :class:`Config`
Args:
filename (str): Name of the file to construct the return object.
Returns:
:class:`Config`: A :class:`Config` object constructed with information in the file.
Raises:
AssertionError: Raises an AssertionError if the file does not exist, or the file is not a .py file.
"""
# check config path
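For illustration, the wrapper above can be used roughly as in this minimal sketch; the dictionary keys and the config path are made up for the example:

from colossalai.context import Config

# wrap a plain dict so its values can be read as attributes
cfg = Config({'batch_size': 64, 'parallel': dict(data=2)})
print(cfg.batch_size)    # -> 64

# build a Config from a python file; an AssertionError is raised if the
# path does not exist or does not point to a .py file
# cfg = Config.from_file('./config.py')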

View File

@@ -22,6 +22,10 @@ class ParallelContext(metaclass=SingletonMeta):
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
Note:
The parallel_mode used in this class should be a member of ``ParallelMode``.
More details about ``ParallelMode`` can be found in
`parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
def __init__(self):
@@ -62,10 +66,12 @@ class ParallelContext(metaclass=SingletonMeta):
def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.
:param config: Either a dict containing the configuration information or the filename
of a file containing the configuration information
:type config: dict or str
:raises TypeError: Raises a TypeError if `config` is neither a dict or a str
Args:
config (dict or str): Either a dict containing the configuration information or the filename
of a file containing the configuration information.
Raises:
TypeError: Raises a TypeError if `config` is neither a dict nor a str.
"""
if isinstance(config, str):
self._config = Config.from_file(config)
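A short usage sketch of load_config; it assumes the ParallelContext singleton is exposed as colossalai.core.global_context (commonly aliased gpc), which is not shown in this diff:

from colossalai.core import global_context as gpc

gpc.load_config('./config.py')               # str: parsed via Config.from_file
gpc.load_config({'parallel': dict(data=2)})  # dict: wrapped into a Config
# any other type, e.g. a list, raises TypeError
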
@@ -81,20 +87,21 @@ class ParallelContext(metaclass=SingletonMeta):
def get_global_rank(self):
"""Returns the global rank of the current device.
:return: The global rank of the current device
:rtype: int
Returns:
int: The global rank of the current device
"""
return self._global_ranks[ParallelMode.GLOBAL]
def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the global rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param rank: The rank to be added
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._global_ranks[parallel_mode] = rank
@@ -102,12 +109,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_local_rank(self, parallel_mode: ParallelMode):
"""Returns the local rank of the current device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The local rank of the current device for `parallel_mode`
:rtype: int
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
int: The local rank of the current device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
return self._local_ranks[parallel_mode]
@@ -115,12 +125,13 @@ class ParallelContext(metaclass=SingletonMeta):
def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the local rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param rank: The rank to be added
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank
@@ -128,12 +139,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_next_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the next device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The global rank of the next device for `parallel_mode`
:rtype: int
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
int: The global rank of the next device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
@@ -147,12 +161,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_prev_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the previous device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The global rank of the previous device for `parallel_mode`
:rtype: int
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
int: The global rank of the previous device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
@@ -167,13 +184,16 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`
:rtype: bool
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
bool: A boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
"""
rank = self.get_local_rank(parallel_mode)
return rank == 0
@@ -182,13 +202,16 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`
:rtype: bool
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
bool: A boolean value indicating whether the current device is the last one
among its group for `parallel_mode`.
"""
rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
@@ -210,12 +233,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The world size for `parallel_mode`
:rtype: int
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
int: The world size for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
return self._world_sizes[parallel_mode]
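The getters above are typically combined as in the sketch below; it assumes an already initialized distributed run and the global_context singleton (gpc), and ParallelMode.DATA is used purely as an example mode:

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

global_rank = gpc.get_global_rank()
local_rank = gpc.get_local_rank(ParallelMode.DATA)
world_size = gpc.get_world_size(ParallelMode.DATA)
if gpc.is_first_rank(ParallelMode.DATA):
    print(f'rank {global_rank}: data-parallel rank {local_rank} of {world_size}')
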
@@ -223,12 +249,13 @@ class ParallelContext(metaclass=SingletonMeta):
def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
"""Adds world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param world_size: The world size to be added
:type world_size: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
world_size (int): The world size to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._world_sizes[parallel_mode] = world_size
@@ -236,12 +263,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_group(self, parallel_mode: ParallelMode):
"""Returns the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
return self._groups[parallel_mode]
@@ -249,12 +279,13 @@ class ParallelContext(metaclass=SingletonMeta):
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param group: The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
group (torch.distributed.ProcessGroup): The group to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group
@@ -262,12 +293,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: the rank of the current device for `parallel_mode` in the group
:rtype: int
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
Returns:
int: The rank of the current device for `parallel_mode` in the group.
"""
self._check_parallel_mode(parallel_mode)
return self._ranks_in_group[parallel_mode]
@@ -275,28 +309,26 @@ class ParallelContext(metaclass=SingletonMeta):
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
"""Adds the ranks of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param ranks: List of ranks to be added
:type ranks: list
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
ranks (list): List of ranks to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
"""Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
:type world_size: int
:param host: the master address for distributed training
:type host: str
:param port: the master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
Args:
rank (int): Rank for the default process group.
world_size (int): World size of the default process group.
backend (str): Backend for ``torch.distributed``.
host (str): The master address for distributed training.
port (int): The master port for distributed training.
"""
# initialize the default process group
init_method = f'tcp://{host}:{port}'
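The rest of the method is not shown in this hunk; a simplified reconstruction of what the arguments are used for (a sketch, not the verbatim source) looks like this:

import torch.distributed as dist

def _init_global_dist_sketch(rank: int, world_size: int, backend: str,
                             host: str, port: int) -> None:
    # hand the TCP rendezvous address to torch.distributed; the surrounding
    # class presumably also records the resulting GLOBAL rank and group via
    # the add_* methods documented above
    init_method = f'tcp://{host}:{port}'
    dist.init_process_group(rank=rank, world_size=world_size,
                            backend=backend, init_method=init_method)
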
@@ -315,8 +347,9 @@ class ParallelContext(metaclass=SingletonMeta):
def check_sanity(self):
"""Checks sanity of the parallel context.
:raises AssertionError: Raises an AssertionError if the world size does not equal to the product
of data paralle size, pipeline parallel size and tensor parallel size
Raises:
AssertionError: Raises an AssertionError if the world size does not equal the product
of data parallel size, pipeline parallel size and tensor parallel size.
"""
dps = self.data_parallel_size
pps = self.pipeline_parallel_size
@@ -341,7 +374,8 @@ class ParallelContext(metaclass=SingletonMeta):
def init_parallel_groups(self):
"""Initializes the parallel groups.
:raises AssertionError: Raises an AssertionError if the field paralle is not present in the config file
Raises:
AssertionError: Raises an AssertionError if the field ``parallel`` is not present in the config file.
"""
# get rank and world size
@@ -411,11 +445,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns a boolean value indicating whether `parallel_mode` is initialized
in the current system.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether `parallel_mode` is initialized
in the current system
:rtype: bool
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Returns:
bool: A boolean value indicating whether `parallel_mode` is initialized in the current system.
"""
return parallel_mode in self._groups
@@ -432,8 +466,8 @@ class ParallelContext(metaclass=SingletonMeta):
def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices.
:param device_ordinal: the device id to be bound to
:type device_ordinal: int, optional
Args:
device_ordinal (int, optional): The device id to be bound to.
"""
global_rank = self.get_global_rank()
if device_ordinal is None:
@@ -447,8 +481,8 @@ class ParallelContext(metaclass=SingletonMeta):
def set_seed(self, seed: int):
"""Sets seeds for all random libraries.
:param seed: seed for random states
:type seed: int
Args:
seed (int): Seed for random states.
"""
random.seed(seed)
np.random.seed(seed)
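The hunk is truncated here; below is a condensed sketch of the all-library seeding pattern the method describes (the torch calls are an assumption, since only the random and numpy calls are visible in this diff):

import random
import numpy as np
import torch

def set_all_seeds(seed: int) -> None:
    # seed every RNG that training code commonly touches
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                 # CPU generator
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)    # every visible GPU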

View File

@@ -11,8 +11,16 @@ from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_1D(ProcessGroupInitializer):
'''A ProcessGroupInitializer for 1d tensor parallelism.
'''
"""A ProcessGroupInitializer for 1d tensor parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -20,8 +28,10 @@ class Initializer_1D(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: Tuple
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
1D tensor parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
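For context, the 5-tuple returned by init_dist_group feeds the add_* methods documented earlier on ParallelContext; a hypothetical glue helper (the names ctx, initializer and register_group are illustrative) might look like:

def register_group(ctx, initializer) -> None:
    # consume the documented (local_rank, group_world_size, process_group,
    # ranks_in_group, mode) tuple and store it in the parallel context
    local_rank, group_world_size, process_group, ranks_in_group, mode = \
        initializer.init_dist_group()
    ctx.add_local_rank(mode, local_rank)
    ctx.add_world_size(mode, group_world_size)
    ctx.add_group(mode, process_group)
    ctx.add_ranks_in_group(mode, ranks_in_group)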

View File

@@ -22,12 +22,16 @@ def _check_summa_env_var(summa_dim):
class Initializer_2D_Row(ProcessGroupInitializer):
"""2d tensor parallel initialization among rows.
:param num_group: The number of all tensor groups
:param summa_dim: The dimension of SUMMA
:param args: Args used to initialize base class
:param kwargs: Kwargs used to initialize base class
:type num_group: int
:type summa_dim: int
Args:
num_group (int): The number of all tensor groups.
summa_dim (int): The dimension of SUMMA.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group, summa_dim, *args, **kwargs):
@@ -37,9 +41,9 @@ class Initializer_2D_Row(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor row parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2D tensor row parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -64,13 +68,15 @@ class Initializer_2D_Row(ProcessGroupInitializer):
class Initializer_2D_Col(ProcessGroupInitializer):
"""2d tensor parallel initialization among cols.
:param num_group: The number of all tensor groups
:param summa_dim: The dimension of SUMMA
:param args: Args used to initialize base class
:param kwargs: Kwargs used to initialize base class
:type num_group: int
:type summa_dim: int
Args:
num_group (int): The number of all tensor groups.
summa_dim (int): The dimension of SUMMA.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group, summa_dim, *args, **kwargs):
@@ -81,8 +87,9 @@ class Initializer_2D_Col(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor col parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2D tensor col parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -109,8 +116,13 @@ class Initializer_2D(ProcessGroupInitializer):
"""
Serve as the single entry point to 2D parallel initialization.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
@@ -127,8 +139,10 @@ class Initializer_2D(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
2D tensor parallelism's information in a list of tuples.
"""
parallel_setting = [self.row_initializer.init_dist_group(), self.col_initializer.init_dist_group()]
return parallel_setting
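Unlike the single-group initializers, Initializer_2D (and the later 2.5D, 3D, pipeline and sequence initializers) returns a list of such tuples, one per sub-group; a hypothetical consumer, with the helper name made up for illustration:

def describe_groups(initializer) -> None:
    # multi-group initializers return one 5-tuple per created process group
    for local_rank, group_world_size, _, ranks_in_group, mode in \
            initializer.init_dist_group():
        print(f'{mode}: local rank {local_rank} of {group_world_size}, '
              f'group ranks {ranks_in_group}')
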

View File

@@ -31,14 +31,17 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int):
# i row j col k dep
class Initializer_2p5D_ROW(ProcessGroupInitializer):
"""2p5d tensor parallel initialization among rows.
"""2.5d tensor parallel initialization among rows.
:param tesseract_dim: The dimension of tesseract
:param tesseract_dep: The dimension of depth
:param args: Args used to initialize base class
:type tesseract_dim: int
:type tesseract_dep: int
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
@@ -50,10 +53,11 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
"""Initialize 2p5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
"""Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor row parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor row parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -80,14 +84,17 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
class Initializer_2p5D_Col(ProcessGroupInitializer):
"""2p5d tensor parallel initialization among cols.
"""2.5d tensor parallel initialization among cols.
:param tesseract_dim: The dimension of tesseract
:param tesseract_dep: The dimension of depth
:param args: Args used to initialize base class
:type tesseract_dim: int
:type tesseract_dep: int
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
@@ -99,10 +106,11 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
"""Initialize 2p5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
"""Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor col parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor col parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -129,14 +137,17 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
class Initializer_2p5D_Dep(ProcessGroupInitializer):
"""2p5D tensor parallel initialization among depths.
"""2.5D tensor parallel initialization among depths.
:param tesseract_dim: The dimension of tesseract
:param tesseract_dep: The dimension of depth
:param args: Args used to initialize base class
:type tesseract_dim: int
:type tesseract_dep: int
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
@@ -148,10 +159,11 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
"""Initialize 2p5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
"""Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor depth parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor depth parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -179,14 +191,17 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
# i row j col k dep
class Initializer_2p5D_XZ(ProcessGroupInitializer):
"""2p5d tensor parallel initialization among cols times dep.
"""2.5d tensor parallel initialization among cols times dep.
:param tesseract_dim: The dimension of tesseract
:param tesseract_dep: The dimension of depth
:param args: Args used to initialize base class
:type tesseract_dim: int
:type tesseract_dep: int
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
@@ -198,10 +213,11 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
"""Initialize 2p5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
"""Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor colXdepth parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor colXdepth parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -232,20 +248,14 @@ class Initializer_2p5D(ProcessGroupInitializer):
"""
Serve as the single entry point to Tesseract parallel initialization.
:param rank: The rank of current process
:param world_size: Size of whole communication world
:param config: Running configuration
:param data_parallel_size: Size of data parallel
:param pipeline_parallel_size: Size of pipeline parallel
:param tensor_parallel_size: Size of tensor parallel
:param depth: The depth of 2p5d parallel
:type rank: int
:type world_size: int
:type config: Config
:type data_parallel_size: int
:type pipeline_parallel_size: int
:type tensor_parallel_size: int
:type depth: int
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
depth (int): The depth of 2.5D parallel.
"""
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int,
@@ -266,9 +276,11 @@ class Initializer_2p5D(ProcessGroupInitializer):
self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args)
def init_dist_group(self):
"""Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: Whole 2p5D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
"""Initialize 2.5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Whole 2.5D tensor parallelism's information in a list of tuples.
"""
parallel_setting = [
self.col_initializer.init_dist_group(),

View File

@@ -26,12 +26,15 @@ def _check_depth_env_var(depth):
class Initializer_3D_Input(ProcessGroupInitializer):
"""3D tensor parallel initialization among input.
:param num_group: The number of all tensor groups
:param depth: Depth of 3D parallelism
:param args: Args used in base class
:type num_group: int
:type depth: int
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
@@ -42,8 +45,9 @@ class Initializer_3D_Input(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among input
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -70,12 +74,15 @@ class Initializer_3D_Input(ProcessGroupInitializer):
class Initializer_3D_Weight(ProcessGroupInitializer):
"""3D tensor parallel initialization among weight.
:param num_group: The number of all tensor groups
:param depth: Depth of 3D parallelism
:param args: Args used in base class
:type num_group: int
:type depth: int
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
@@ -86,8 +93,9 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among weight
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among weight in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -114,12 +122,15 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
class Initializer_3D_Output(ProcessGroupInitializer):
"""3D tensor parallel initialization among output.
:param num_group: The number of all tensor groups
:param depth: Depth of 3D parallelism
:param args: Args used in base class
:type num_group: int
:type depth: int
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
@@ -130,8 +141,9 @@ class Initializer_3D_Output(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among output
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among output in a tuple.
"""
local_rank = None
ranks_in_group = None
@@ -158,7 +170,14 @@ class Initializer_3D_Output(ProcessGroupInitializer):
@DIST_GROUP_INITIALIZER.register_module
class Initializer_3D(ProcessGroupInitializer):
"""Serve as the single entry point to 3D parallel initialization.
:param args: Args used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args):
@@ -175,8 +194,10 @@ class Initializer_3D(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Whole 3D tensor parallelism's information in a list of tuples.
"""
parallel_setting = [
self.input_initializer.init_dist_group(),

View File

@@ -12,8 +12,13 @@ from ..parallel_mode import ParallelMode
class Initializer_Data(ProcessGroupInitializer):
"""A ProcessGroupInitializer for data parallelism.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -22,8 +27,9 @@ class Initializer_Data(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.
:return: Data parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
Data parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None

View File

@@ -12,8 +12,13 @@ class Initializer_Model(ProcessGroupInitializer):
"""A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel
groups).
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
@@ -24,8 +29,9 @@ class Initializer_Model(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize model parallel groups, and assign local_ranks and groups to each gpu.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: Tuple
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
Model parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None

View File

@@ -12,8 +12,13 @@ from ..parallel_mode import ParallelMode
class Initializer_Pipeline(ProcessGroupInitializer):
"""A ProcessGroupInitializer for pipeline parallelism.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -23,8 +28,9 @@ class Initializer_Pipeline(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.
:return: Pipeline parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Pipeline parallelism's information in a list of tuples.
"""
dist_settings = list()
for i in range(self.data_parallel_size):

View File

@@ -15,8 +15,13 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
In Sequence Parallelism, each GPU holds a full copy of the model weights;
thus, gradient all-reduce occurs across all processes in the same pipeline stage.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
@@ -27,8 +32,8 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize Sequence Parallel process groups used for gradient all-reduce.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: Tuple
Returns:
Tuple: A tuple (local_rank, group_world_size, process_group, ranks_in_group, mode).
"""
local_rank = None
ranks_in_group = None
@@ -52,8 +57,13 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
class Initializer_Sequence(ProcessGroupInitializer):
"""A ProcessGroupInitializer for sequence parallelism.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self,
*args, **kwargs):
@@ -66,11 +76,12 @@ class Initializer_Sequence(ProcessGroupInitializer):
"""Initialize Sequence parallel process groups and assign local_ranks and groups to each gpu.
Sequence parallelism requires 2 process groups. The first is for model forward where several processes
exchange paritial query, key and value embedding to compute self attention values. The second is for
exchange partial query, key and value embedding to compute self attention values. The second is for
all-reduce to synchronize the model parameters.
:return: Sequence parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Sequence parallelism's information in a list of tuples.
"""
parallel_setting = []

View File

@@ -12,8 +12,13 @@ from ..parallel_mode import ParallelMode
class Initializer_Tensor(ProcessGroupInitializer):
"""A ProcessGroupInitializer for tensor parallelism.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -22,8 +27,9 @@ class Initializer_Tensor(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: Tensor parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
Tensor parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None

View File

@@ -9,19 +9,13 @@ from colossalai.context import Config
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
:param rank: The rank of current process
:param world_size: Size of whole communication world
:param config: Running configuration
:param data_parallel_size: Size of data parallel
:param pipeline_parallel_size: Size of pipeline parallel
:param tensor_parallel_size: Size of tensor parallel
:type rank: int
:type world_size: int
:type config: Config
:type data_parallel_size: int
:type pipeline_parallel_size: int
:type tensor_parallel_size: int
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self,
rank: int,
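To tie the base class together with the concrete initializers above, here is a skeletal custom subclass following the same pattern; the import paths, the registry decorator location, and the self.rank / self.world_size attribute names are inferred from this diff and should be treated as assumptions:

import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.context.process_group_initializer import ProcessGroupInitializer

@DIST_GROUP_INITIALIZER.register_module
class Initializer_Example(ProcessGroupInitializer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def init_dist_group(self):
        # trivial example: one group spanning every rank; a real initializer
        # would carve the world into its own sub-groups
        ranks = list(range(self.world_size))
        process_group = dist.new_group(ranks)
        local_rank = ranks.index(self.rank)
        # ParallelMode.GLOBAL is only a placeholder mode for this sketch
        return local_rank, len(ranks), process_group, ranks, ParallelMode.GLOBAL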

View File

@@ -16,8 +16,8 @@ _SEED_MANAGER = SeedManager()
def get_seeds():
"""Returns the seeds of the seed manager.
:return: The seeds of the seed manager
:rtype: dict
Returns:
dict: The seeds of the seed manager.
"""
return _SEED_MANAGER.seeds
@@ -25,8 +25,8 @@ def get_seeds():
def get_states(copy=False):
"""Returns the seed states of the seed manager.
:return: The seed states of the seed manager
:rtype: dict
Returns:
dict: The seed states of the seed manager.
"""
states = _SEED_MANAGER.seed_states
@@ -43,8 +43,8 @@ def get_states(copy=False):
def get_current_mode():
"""Returns the current mode of the seed manager.
:return: The current mode of the seed manager.
:rtype: :class:`torch.ByteTensor`
Returns:
:class:`torch.ByteTensor`: The current mode of the seed manager.
"""
return _SEED_MANAGER.current_mode
@@ -52,12 +52,16 @@ def get_current_mode():
def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param seed: The seed to be added
:type seed: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added.
Note:
The parallel_mode should be a member of ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
_SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
@@ -65,8 +69,12 @@ def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
def set_mode(parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be a member of ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
_SEED_MANAGER.set_mode(parallel_mode)
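Taken together, the functional seed interface can be exercised roughly as follows; the module path colossalai.context.random and the seed value are assumptions used purely for illustration:

from colossalai.context import ParallelMode
from colossalai.context.random import add_seed, set_mode, seed

add_seed(ParallelMode.DATA, 1024)   # register a seed for the DATA mode
set_mode(ParallelMode.DATA)         # make it the active RNG state

with seed(ParallelMode.DATA):       # temporarily switch, then restore
    pass  # e.g. a dropout call that must agree across the data-parallel group
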
@@ -74,11 +82,12 @@ def set_mode(parallel_mode: ParallelMode):
def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param state: the state to be set
:type state: :class:`torch.Tensor`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
"""
_SEED_MANAGER.set_state(parallel_mode, state)
@@ -98,6 +107,9 @@ def seed(parallel_mode: ParallelMode):
with seed(ParallelMode.DATA):
output = F.dropout(input)
Note:
The parallel_mode should be a member of ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
try:
# set to new mode
@@ -125,6 +137,9 @@ def with_seed(func, parallel_mode: ParallelMode):
wrapper_forward = with_seed(forward, ParallelMode.DATA)
out = wrapped_forward(input)
Note:
The parallel_mode should be a member of ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
@functools.wraps(func)

View File

@@ -9,6 +9,10 @@ from colossalai.context.parallel_mode import ParallelMode
class SeedManager:
"""This class is a manager of all random seeds involved in the system.
Note:
The parallel_mode should be a member of ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
def __init__(self):
@@ -30,12 +34,12 @@ class SeedManager:
def set_state(self, parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param state: the state to be set
:type state: :class:`torch.Tensor`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
"""
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
self._seed_states[parallel_mode] = state
@@ -43,8 +47,8 @@ class SeedManager:
def set_mode(self, parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
"""
if self.current_mode:
# save the current state for current mode
@@ -57,14 +61,14 @@ class SeedManager:
def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrtie: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param seed: The seed to be added
:type seed: int
:param overwrtie: Whether allows to overwrite the seed that has been set already
:type overwrtie: bool, optional
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added.
overwrtie (bool, optional): Whether to allow overwriting a seed that has already been set.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added.
"""
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
if overwrtie is False: