[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
This commit is contained in:
Hongxin Liu
2023-09-18 16:31:06 +08:00
committed by GitHub
parent 32e7f99416
commit b5f9e37c70
342 changed files with 2919 additions and 4182 deletions

View File

@@ -0,0 +1,4 @@
from .parallel_context import ParallelContext
from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *

View File

@@ -0,0 +1,578 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import random
import socket
from collections import Counter
from threading import local
from typing import Union
import numpy as np
import torch
import torch.distributed as dist
from colossalai.context.config import Config
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.legacy.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from colossalai.logging import get_dist_logger
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
class ParallelContext(metaclass=SingletonMeta):
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
Note:
The parallel_mode used in this class should be concluded in ``ParallelMode``.
More details about ``ParallelMode`` could be found in
`parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
def __init__(self):
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
self._world_sizes = dict()
self._groups = dict()
self._cpu_groups = dict()
self._ranks_in_group = dict()
# load config from file
self._config = None
# default 3D parallel args, will be overwritten during process group initialization
self.world_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
# logging
self._verbose = False
self._logger = get_dist_logger()
@property
def config(self):
return self._config
@property
def verbose(self):
return self._verbose
@verbose.setter
def verbose(self, verbose_: bool):
self._verbose = verbose_
def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.
Args:
config (dict or str): Either a dict containing the configuration information or the filename
of a file containing the configuration information.
Raises:
TypeError: Raises a TypeError if `config` is neither a dict nor a str.
"""
if isinstance(config, str):
self._config = Config.from_file(config)
elif isinstance(config, dict):
self._config = Config(config)
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")
def detect_num_processes_on_current_node(self):
hostname = socket.gethostname()
hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
counter = Counter(hostname_list)
self.num_processes_on_current_node = counter[hostname]
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode), \
f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}'
def get_global_rank(self):
"""Returns the global rank of the current device.
Returns:
int: The global rank of the current device
"""
return self._global_ranks[ParallelMode.GLOBAL]
def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the global rank of the current device for `parallel_mode` to the context.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._global_ranks[parallel_mode] = rank
def get_local_rank(self, parallel_mode: ParallelMode):
"""Returns the local rank of the current device.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The local rank of the current device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
return self._local_ranks[parallel_mode]
def _add_local_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the local rank of the current device for `parallel_mode` to the context.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank
def get_next_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the next device.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The global rank of the next device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
# get rank and world size
local_rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
ranks_in_group = self.get_ranks_in_group(parallel_mode)
return ranks_in_group[(local_rank + 1) % world_size]
def get_prev_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the previous device.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The global rank of the previous device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
# get rank and world size
local_rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
ranks_in_group = self.get_ranks_in_group(parallel_mode)
return ranks_in_group[(local_rank - 1) % world_size]
def is_first_rank(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
bool: a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
"""
rank = self.get_local_rank(parallel_mode)
return rank == 0
def is_last_rank(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
bool: a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
"""
rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
return rank == world_size - 1
def is_pipeline_first_stage(self, ignore_virtual=False):
if not ignore_virtual:
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0:
return False
return self.is_first_rank(ParallelMode.PIPELINE)
def is_pipeline_last_stage(self, ignore_virtual=False):
if not ignore_virtual:
if self.virtual_pipeline_parallel_size \
is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
return False
return self.is_last_rank(ParallelMode.PIPELINE)
def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The world size for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
return self._world_sizes[parallel_mode]
def _add_world_size(self, parallel_mode: ParallelMode, world_size: int):
"""Adds world size for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode corresponding to the process group
world_size (int): The world size to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._world_sizes[parallel_mode] = world_size
def get_group(self, parallel_mode: ParallelMode):
"""Returns the group of the current device for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
"""
self._check_parallel_mode(parallel_mode)
return self._groups[parallel_mode]
def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the group of the current device for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
group (torch.distributed.ProcessGroup): The group to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group
def get_cpu_group(self, parallel_mode: ParallelMode):
"""Returns the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.legacy.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`
:return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup
"""
self._check_parallel_mode(parallel_mode)
return self._cpu_groups[parallel_mode]
def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.legacy.context.ParallelMode`
:param group: The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._cpu_groups[parallel_mode] = group
def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The rank of the current device for `parallel_mode` in the group.
"""
self._check_parallel_mode(parallel_mode)
return self._ranks_in_group[parallel_mode]
def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
"""Adds the ranks of the current device for `parallel_mode` in the group.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
ranks (list): List of ranks to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
"""Initializes the global distributed environment
Args:
rank (int): rank for the default process group.
world_size (int): world size of the default process group.
backend (str): backend for ``torch.distributed``
host (str): the master address for distributed training.
port (str): the master port for distributed training
"""
# initialize the default process group
init_method = f'tcp://[{host}]:{port}'
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# None will give the default global process group for pytorch dist operations
ranks = list(range(world_size))
cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
self.add_global_rank(ParallelMode.GLOBAL, rank)
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
self._add_local_rank(mode, local_rank)
self._add_world_size(mode, world_size)
self._add_group(mode, process_group)
self._add_cpu_group(mode, cpu_group)
self._add_ranks_in_group(mode, ranks_in_group)
def check_sanity(self):
"""Checks sanity of the parallel context.
Raises:
AssertionError: Raises an AssertionError if the world size does not equal to the product
of data parallel size, pipeline parallel size and tensor parallel size.
"""
dps = self.data_parallel_size
pps = self.pipeline_parallel_size
tps = self.tensor_parallel_size
ws = self.world_size
assert ws == dps * pps * \
tps, f"Expected the world size {ws} to be equal to data" \
f" parallel size ({dps}) * pipeline parallel size " \
f"({pps}) * tensor parallel size ({tps})"
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
if key in config:
ele = config[key]
if isinstance(ele, int):
setattr(self, attr_name, ele)
elif isinstance(ele, dict):
setattr(self, attr_name, ele['size'])
else:
raise NotImplementedError(
f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')
def init_parallel_groups(self):
"""Initializes the parallel groups.
Raises:
AssertionError: Raises an AssertionError if the field parallel is not present in the config file.
"""
# get rank and world size
rank = self.get_global_rank()
world_size = self.get_world_size(ParallelMode.GLOBAL)
self.world_size = world_size
# set parallel size as attributes for global context
parallel_config = self.config.get('parallel', None)
if parallel_config is not None:
self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size')
self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size')
# the user should not set the data parallel size manually
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# get the tensor parallel mode and check
tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in \
parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, \
f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
env.mode = tensor_parallel_mode
self.check_sanity()
pg_init = []
# LSG: init data parallel process group for compatibility with other parallel module such as zero
pg_init.append(dict(type=INITIALIZER_MAPPING['data']))
# LSG: init model parallel process group for compatibility with amp and clip grad
pg_init.append(dict(type=INITIALIZER_MAPPING['model']))
if self.pipeline_parallel_size > 1:
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))
# init specific tensor parallel group
if tensor_parallel_mode is not None:
tensor_parallel_cfg = parallel_config['tensor'].copy()
# remove duplicate parameters
tensor_parallel_cfg.pop('mode')
tensor_parallel_cfg.pop('size')
# add this config to initialize later
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))
# run initialization of different process groups
for initializer_cfg in pg_init:
cfg = initializer_cfg.copy()
initializer_type = cfg.pop('type')
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size, **cfg)
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):
for args in parallel_setting:
self._register_dist(*args)
else:
self._register_dist(*parallel_setting)
def is_initialized(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether `parallel_mode` is initialized
in the current system.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Returns:
bool: a boolean value indicating whether `parallel_mode` is initialized in the current system.
"""
return parallel_mode in self._groups
def destroy(self):
"""Destroys the current distributed parallel environment.
"""
for mode, group in self._groups.items():
if mode is not ParallelMode.GLOBAL:
dist.destroy_process_group(group)
# destroy global process group
dist.destroy_process_group()
self._groups.clear()
def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices.
Args:
device_ordinal (int, optional): the device id to be bound to
"""
global_rank = self.get_global_rank()
if device_ordinal is None:
devices_per_node = torch.cuda.device_count()
device_ordinal = global_rank % devices_per_node
torch.cuda.set_device(device_ordinal)
if self._verbose:
self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}')
def set_seed(self, seed: int):
"""Sets seeds for all random libraries.
Args:
seed (int): seed for random states
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
global_rank = self.get_global_rank()
if torch.cuda.is_available():
# create random seed for different parallel modes
# data parallel seed are kept the same
parallel_seed = seed
add_seed(ParallelMode.DATA, parallel_seed)
# model parallel seeds are different across ranks
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
# add seed for data parallel and tensor parallel only
if self.is_initialized(ParallelMode.TENSOR):
tp_rank = self.get_local_rank(ParallelMode.TENSOR)
# 100 is only to increase the diff in seeds between pipeline stages
tp_rank_with_offset = tp_rank + pipeline_offset * 1024
tp_seed = seed + tp_rank_with_offset
add_seed(ParallelMode.TENSOR, tp_seed)
set_mode(ParallelMode.DATA)
seeds = get_seeds()
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
if self._verbose:
self._logger.info(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.")
else:
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
ranks=[0])
self._logger.info(
'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
ranks=[0])
def set_virtual_pipeline_parallel_size(self, size):
self.virtual_pipeline_parallel_size = size
def set_virtual_pipeline_parallel_rank(self, rank):
self.virtual_pipeline_parallel_rank = rank
global_context = ParallelContext()

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from enum import Enum
# parallel modes
class ParallelMode(Enum):
"""This is an enumeration class containing all possible parallel modes.
"""
GLOBAL = 'global'
# common parallel
DATA = 'data'
# model parallel - containing tensor and pipeline parallel groups
# this is added to facilitate amp and grad clipping in hybrid parallel
MODEL = 'model'
# pipeline parallel
PIPELINE = 'pipe'
# containing all ranks in tensor parallel
TENSOR = 'tensor'
# sequence parallel
SEQUENCE = 'sequence'
SEQUENCE_DP = 'sequence_dp'
# 1D Parallel
PARALLEL_1D = '1d'
# 2D parallel
PARALLEL_2D_ROW = '2d_row'
PARALLEL_2D_COL = '2d_col'
# 3D parallel
PARALLEL_3D_INPUT = '3d_input'
PARALLEL_3D_WEIGHT = '3d_weight'
PARALLEL_3D_OUTPUT = '3d_output'
PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight"
PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight"
# 2.5D parallel
PARALLEL_2P5D_ROW = '2p5d_row'
PARALLEL_2P5D_COL = '2p5d_col'
PARALLEL_2P5D_DEP = '2p5d_dep'
PARALLEL_2P5D_XZ = '2p5d_xz'

View File

@@ -0,0 +1,15 @@
from .initializer_1d import Initializer_1D
from .initializer_2d import Initializer_2D
from .initializer_2p5d import Initializer_2p5D
from .initializer_3d import Initializer_3D
from .initializer_data import Initializer_Data
from .initializer_model import Initializer_Model
from .initializer_pipeline import Initializer_Pipeline
from .initializer_sequence import Initializer_Sequence
from .initializer_tensor import Initializer_Tensor
from .process_group_initializer import ProcessGroupInitializer
__all__ = [
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D',
'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
]

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_1D(ProcessGroupInitializer):
"""A ProcessGroupInitializer for 1d tensor parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_group = self.world_size // self.tensor_parallel_size
def init_dist_group(self):
"""Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
1D tensor parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
env.parallel_input_1d = False
for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

View File

@@ -0,0 +1,155 @@
import math
import torch.distributed as dist
from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
def _check_summa_env_var(summa_dim):
# check environment variable for SUMMA
env_summa_dim = env.summa_dim
if env_summa_dim:
assert int(env_summa_dim) == summa_dim, \
'SUMMA_DIM has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
env.summa_dim = summa_dim
class Initializer_2D_Row(ProcessGroupInitializer):
"""2d tensor parallel initialization among rows.
Args:
num_group (int): The number of all tensor groups.
summa_dim (int): The dimension of SUMMA.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group, summa_dim, *args, **kwargs):
super(Initializer_2D_Row, self).__init__(*args, **kwargs)
self.num_group = num_group
self.summa_dim = summa_dim
def init_dist_group(self):
"""Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2D tensor row parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2D_ROW
for i in range(self.num_group):
for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_2D_Col(ProcessGroupInitializer):
"""2d tensor parallel initialization among cols.
Args:
num_group (int): The number of all tensor groups.
summa_dim (int): The dimension of SUMMA.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group, summa_dim, *args, **kwargs):
super(Initializer_2D_Col, self).__init__(*args, **kwargs)
self.num_group = num_group
self.summa_dim = summa_dim
def init_dist_group(self):
"""Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2D tensor col parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2D_COL
for i in range(self.num_group):
for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_2D(ProcessGroupInitializer):
"""
Serve as the single entry point to 2D parallel initialization.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_group = self.world_size // self.tensor_parallel_size
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
assert self.tensor_parallel_size == self.summa_dim ** 2, \
"2D summa dim should equal to tensor parallel size ^ 0.5"
_check_summa_env_var(self.summa_dim)
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
self.row_initializer = Initializer_2D_Row(self.num_group, self.summa_dim, *args, **kwargs)
def init_dist_group(self):
"""Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
2D tensor parallelism's information in a list of tuples.
"""
parallel_setting = [self.row_initializer.init_dist_group(), self.col_initializer.init_dist_group()]
return parallel_setting

View File

@@ -0,0 +1,298 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch.distributed as dist
from colossalai.context import Config
from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int):
# check global variable for TESSERACT
env_tesseract_dim = env.tesseract_dim
env_tesseract_dep = env.tesseract_dep
if env_tesseract_dim and env_tesseract_dep:
assert int(env_tesseract_dim) == tesseract_dim, \
'TESSERACT_DIM has been set in the current environment and ' \
'does not match with the value passed to this initialized'
assert int(env_tesseract_dep) == tesseract_dep, \
'TESSERACT_DEP has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
env.tesseract_dim = tesseract_dim
env.tesseract_dep = tesseract_dep
# i row j col k dep
class Initializer_2p5D_ROW(ProcessGroupInitializer):
"""2.5d tensor parallel initialization among rows.
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
super(Initializer_2p5D_ROW, self).__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self):
"""Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor row parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_ROW
for h in range(self.num_group):
for j in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k)
for i in range(self.tesseract_dim)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_2p5D_Col(ProcessGroupInitializer):
"""2.5d tensor parallel initialization among cols.
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
super(Initializer_2p5D_Col, self).__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
def init_dist_group(self):
"""Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor col parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_COL
for h in range(self.num_group):
for i in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k)
for j in range(self.tesseract_dim)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_2p5D_Dep(ProcessGroupInitializer):
"""2.5D tensor parallel initialization among depths.
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
super(Initializer_2p5D_Dep, self).__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
def init_dist_group(self):
"""Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor depth parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_DEP
for h in range(self.num_group):
for i in range(self.tesseract_dim):
for j in range(self.tesseract_dim):
ranks = [
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k)
for k in range(self.tesseract_dep)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
# i row j col k dep
class Initializer_2p5D_XZ(ProcessGroupInitializer):
"""2.5d tensor parallel initialization among cols times dep.
Args:
tesseract_dim (int): The dimension of tesseract.
tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
super(Initializer_2p5D_XZ, self).__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
def init_dist_group(self):
"""Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor colXdepth parallelism's information in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_XZ
for h in range(self.num_group):
for i in range(self.tesseract_dim):
ranks = [
h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k)
for k in range(self.tesseract_dep)
for j in range(self.tesseract_dim)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_2p5D(ProcessGroupInitializer):
"""
Serve as the single entry point to Tesseract parallel initialization.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
depth (int): The depth of 2.5d parallel.
"""
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int,
tensor_parallel_size: int, depth: int):
args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size)
super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth))
self.tesseract_dep = depth
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
"2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5"
_check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep)
self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args)
self.row_initializer = Initializer_2p5D_ROW(self.tesseract_dim, self.tesseract_dep, *args)
self.dep_initializer = Initializer_2p5D_Dep(self.tesseract_dim, self.tesseract_dep, *args)
self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args)
def init_dist_group(self):
"""Initialize 2.5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Whole 2.5D tensor parallelism's information in a list of tuples.
"""
parallel_setting = [
self.col_initializer.init_dist_group(),
self.row_initializer.init_dist_group(),
self.dep_initializer.init_dist_group(),
self.xz_initializer.init_dist_group()
]
return parallel_setting

View File

@@ -0,0 +1,329 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch.distributed as dist
from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
def _check_depth_env_var(depth):
# check global variable
env_depth = env.depth_3d
if env_depth:
assert int(env_depth) == depth, \
'DEPTH_3D has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
env.depth_3d = depth
class Initializer_3D_Input(ProcessGroupInitializer):
"""3D tensor parallel initialization among input.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
env.input_group_3d = mode
for h in range(self.num_group):
for i in range(self.depth):
for k in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_Weight(ProcessGroupInitializer):
"""3D tensor parallel initialization among weight.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among weight in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
env.weight_group_3d = mode
for h in range(self.num_group):
for k in range(self.depth):
for j in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_Output(ProcessGroupInitializer):
"""3D tensor parallel initialization among output.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among output in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
env.output_group_3d = mode
for h in range(self.num_group):
for i in range(self.depth):
for j in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_InputxWeight(ProcessGroupInitializer):
"""3D tensor parallel initialization among input.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT
env.input_x_weight_group_3d = mode
for h in range(self.num_group):
for k in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k)
for j in range(self.depth)
for i in range(self.depth)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_OutputxWeight(ProcessGroupInitializer):
"""3D tensor parallel initialization among input.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT
env.output_x_weight_group_3d = mode
for h in range(self.num_group):
for j in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k)
for k in range(self.depth)
for i in range(self.depth)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_3D(ProcessGroupInitializer):
"""Serve as the single entry point to 3D parallel initialization.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args):
super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3))
assert self.tensor_parallel_size == self.depth ** 3, \
f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})'
_check_depth_env_var(self.depth)
self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args)
self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args)
self.input_x_weight_initializer = Initializer_3D_InputxWeight(self.num_group, self.depth, *args)
self.output_x_weight_initializer = Initializer_3D_OutputxWeight(self.num_group, self.depth, *args)
def init_dist_group(self):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Whole 3D tensor parallelism's information in a list of tuples.
"""
parallel_setting = [
self.input_initializer.init_dist_group(),
self.weight_initializer.init_dist_group(),
self.output_initializer.init_dist_group(),
self.input_x_weight_initializer.init_dist_group(),
self.output_x_weight_initializer.init_dist_group()
]
return parallel_setting

View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import distributed as dist
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Data(ProcessGroupInitializer):
"""A ProcessGroupInitializer for data parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_data_parallel_group = self.world_size // self.data_parallel_size
def init_dist_group(self):
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A Data parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.DATA
for i in range(self.num_data_parallel_group):
ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Model(ProcessGroupInitializer):
"""A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel
groups).
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size
self.num_group = self.world_size // self.model_parallel_size
def init_dist_group(self):
"""Initialize model parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A Model parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.MODEL
for i in range(self.num_group):
ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import distributed as dist
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Pipeline(ProcessGroupInitializer):
"""A ProcessGroupInitializer for pipeline parallelism.
Args:
rank (int): The rank of current process
world_size (int): Size of whole communication world
config (Config): Running configuration
data_parallel_size (int): Size of data parallel
pipeline_parallel_size (int): Size of pipeline parallel
tensor_parallel_size (int): Size of tensor parallel
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data_group_size = self.world_size // self.data_parallel_size
self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size
def init_dist_group(self):
"""Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
A Pipeline parallelism's information in list of tuples.
"""
dist_settings = list()
for i in range(self.data_parallel_size):
for j in range(self.pipeline_stage_size):
pipe_ranks = list(
range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size))
pipe_group_size = len(pipe_ranks)
pipe_group = dist.new_group(pipe_ranks)
group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group
if self.rank in pipe_ranks:
local_rank = pipe_ranks.index(self.rank)
group_world_size = pipe_group_size
process_group = pipe_group
cpu_group = group_cpu
ranks_in_group = pipe_ranks
dist_settings.append(
tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group,
ParallelMode.PIPELINE)))
return dist_settings

View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .initializer_tensor import Initializer_Tensor
from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Sequence_DP(ProcessGroupInitializer):
"""A ProcessGroupInitializer for sequence parallelism all-reduce.
In Sequence Parallelism, each GPU holds the full copy of model weights,
thus, gradient all-reduce occurs across all processes in the same pipeline stage
Args:
rank (int): The rank of current process
world_size (int): Size of whole communication world
config (Config): Running configuration
data_parallel_size (int): Size of data parallel
pipeline_parallel_size (int): Size of pipeline parallel
tensor_parallel_size (int): Size of tensor parallel
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dp_size = self.world_size // self.pipeline_parallel_size
self.num_group = self.pipeline_parallel_size
def init_dist_group(self):
"""Initialize Sequence Parallel process groups used for gradient all-reduce.
Returns:
Tuple: A tuple (local_rank, group_world_size, process_group, ranks_in_group, mode).
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.SEQUENCE_DP
for i in range(self.num_group):
ranks = [i * self.dp_size + j for j in range(self.dp_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Sequence(ProcessGroupInitializer):
"""A ProcessGroupInitializer for sequence parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# reuse tensor parallel initializer code
self._sequence_initializer = Initializer_Tensor(*args, **kwargs)
self._sequence_dp_initializer = Initializer_Sequence_DP(*args, **kwargs)
def init_dist_group(self):
"""Initialize Sequence parallel process groups and assign local_ranks and groups to each gpu.
Sequence parallelism requires 2 process groups. The first is for model forward where several processes
exchange partial query, key and value embedding to compute self attention values. The second is for
all-reduce to synchronize the model parameters.
Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
A Sequence parallelism's information in list of tuples.
"""
parallel_setting = []
local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode = \
self._sequence_initializer.init_dist_group()
# change mode to sequence
mode = ParallelMode.SEQUENCE
parallel_setting.append((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode))
parallel_setting.append(self._sequence_dp_initializer.init_dist_group())
return parallel_setting

View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Tensor(ProcessGroupInitializer):
"""A ProcessGroupInitializer for tensor parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size
def init_dist_group(self):
"""Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A Tensor parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.TENSOR
for i in range(self.num_tensor_parallel_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from colossalai.context import Config
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int,
tensor_parallel_size: int):
self.rank = rank
self.world_size = world_size
self.data_parallel_size = data_parallel_size
self.config = config
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
super().__init__()
@abstractmethod
def init_dist_group(self):
pass

View File

@@ -0,0 +1,18 @@
from ._helper import (
add_seed,
get_current_mode,
get_seeds,
get_states,
moe_set_seed,
reset_seeds,
seed,
set_mode,
set_seed_states,
sync_states,
with_seed,
)
__all__ = [
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
'sync_states', 'moe_set_seed', 'reset_seeds'
]

View File

@@ -0,0 +1,172 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import functools
from contextlib import contextmanager
import torch.cuda
from torch import Tensor
from ..parallel_mode import ParallelMode
from .seed_manager import SeedManager
_SEED_MANAGER = SeedManager()
def get_seeds():
"""Returns the seeds of the seed manager.
Returns:
dict: The seeds of the seed manager.
"""
return _SEED_MANAGER.seeds
def get_states(copy=False):
"""Returns the seed states of the seed manager.
Returns:
dict: The seed states of the seed manager.
"""
states = _SEED_MANAGER.seed_states
if copy:
new_states = dict()
for parallel_mode, state in states.items():
new_states[parallel_mode] = state.clone()
return new_states
else:
return _SEED_MANAGER.seed_states
def get_current_mode():
"""Returns the current mode of the seed manager.
Returns:
:class:`torch.ByteTensor`: The current mode of the seed manager.
"""
return _SEED_MANAGER.current_mode
def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
_SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
def set_mode(parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
_SEED_MANAGER.set_mode(parallel_mode)
def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
"""
_SEED_MANAGER.set_state(parallel_mode, state)
def sync_states():
current_mode = get_current_mode()
current_states = torch.cuda.get_rng_state()
set_seed_states(current_mode, current_states)
@contextmanager
def seed(parallel_mode: ParallelMode):
""" A context for seed switch
Examples:
>>> with seed(ParallelMode.DATA):
>>> output = F.dropout(input)
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
try:
# set to new mode
current_mode = _SEED_MANAGER.current_mode
yield _SEED_MANAGER.set_mode(parallel_mode)
finally:
# recover
_SEED_MANAGER.set_mode(current_mode)
def with_seed(func, parallel_mode: ParallelMode):
"""
A function wrapper which executes the function with a specified seed.
Examples:
>>> # use with decorator
>>> @with_seed(ParallelMode.DATA)
>>> def forward(input):
>>> return F.dropout(input)
>>> out = forward(input)
>>> # OR use it inline
>>> def forward(input):
>>> return F.dropout(input)
>>> wrapper_forward = with_seed(forward, ParallelMode.DATA)
>>> out = wrapped_forward(input)
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# switch mode
current_mode = _SEED_MANAGER.current_mode
_SEED_MANAGER.set_mode(parallel_mode)
# exec func
out = func(*args, **kwargs)
# recover state
_SEED_MANAGER.set_mode(current_mode)
return out
return wrapper
def moe_set_seed(seed):
if torch.cuda.is_available():
from colossalai.legacy.core import global_context as gpc
global_rank = gpc.get_global_rank()
diff_seed = seed + global_rank
add_seed(ParallelMode.TENSOR, diff_seed, True)
print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True)
def reset_seeds():
_SEED_MANAGER.reset()

View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import Tensor
from colossalai.legacy.context.parallel_mode import ParallelMode
class SeedManager:
"""This class is a manager of all random seeds involved in the system.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
def __init__(self):
self._current_mode = None
self._seeds = dict()
self._seed_states = dict()
@property
def current_mode(self):
return self._current_mode
@property
def seeds(self):
return self._seeds
@property
def seed_states(self):
return self._seed_states
def set_state(self, parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
"""
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
self._seed_states[parallel_mode] = state
def set_mode(self, parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
"""
if self.current_mode:
# save the current state for current mode
self._seed_states[self._current_mode] = torch.cuda.get_rng_state()
# set the new state for new mode
self._current_mode = parallel_mode
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added.
overwrite (bool, optional): Whether allows to overwrite the seed that has been set already
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode`
or the seed for `parallel_mode` has been added.
"""
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
if overwrite is False:
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
elif parallel_mode in self._seed_states:
print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)
current_state = torch.cuda.get_rng_state()
torch.cuda.manual_seed(seed)
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
self._seeds[parallel_mode] = seed
torch.cuda.set_rng_state(current_state)
def reset(self):
self._current_mode = None
self._seeds = dict()
self._seed_states = dict()