mirror of https://github.com/hpcaitech/ColossalAI.git
Migrated project
colossalai/context/parallel_context.py (new file, 454 lines)
@@ -0,0 +1,454 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
import random
from typing import Union

import numpy as np
import torch
import torch.distributed as dist

from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
from colossalai.registry import DIST_GROUP_INITIALIZER
from ._utils import set_parallel_size
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode


class ParallelContext:
    """This class provides interface functions for users to get the parallel context,
    such as the global rank, the local rank, the world size, etc. of each device.

    :param args: The distributed arguments in the system
    :type args: dict
    """

    def __init__(self, args=None):
        # distributed settings
        self._global_ranks = dict()
        self._local_ranks = dict()
        self._world_sizes = dict()
        self._groups = dict()
        self._ranks_in_group = dict()

        # load config from file
        self._dist_args = args
        self._config = None

        # default 3D parallel args, will be overwritten during process group initialization
        self.world_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 1
        self.tensor_parallel_size = 1

    @property
    def config(self):
        return self._config

    def load_config(self, config: Union[dict, str]):
        """Loads the configuration from either a dict or a file.

        :param config: Either a dict containing the configuration information or the filename
            of a file containing the configuration information
        :type config: dict or str
        :raises TypeError: Raises a TypeError if `config` is neither a dict nor a str
        """
        if isinstance(config, str):
            self._config = Config.from_file(config)
        elif isinstance(config, dict):
            self._config = Config(config)
        else:
            raise TypeError("Invalid type for config, only dictionary or string is supported")
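
    # Usage sketch (hypothetical values, assuming '2d' is an allowed tensor mode):
    # the context accepts either a config file path or a plain dict, e.g.
    #
    #   ctx = ParallelContext()
    #   ctx.load_config(dict(parallel=dict(pipeline=dict(size=2),
    #                                      tensor=dict(size=2, mode='2d'))))
    #
    # Passing any other type (e.g. a list) raises a TypeError.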

    def set_dist_args(self, args):
        """Sets the distributed arguments.

        :param args: The distributed arguments in the system
        :type args: dict
        """
        self._dist_args = args

    @staticmethod
    def _check_parallel_mode(parallel_mode: ParallelMode):
        assert isinstance(parallel_mode, ParallelMode)

    def get_global_rank(self):
        """Returns the global rank of the current device.

        :return: The global rank of the current device
        :rtype: int
        """
        return self._global_ranks[ParallelMode.GLOBAL]

    def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
        """Adds the global rank of the current device for `parallel_mode` to the context.

        :param parallel_mode: The parallel mode for the rank
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param rank: The rank to be added
        :type rank: int
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        """
        self._check_parallel_mode(parallel_mode)
        self._global_ranks[parallel_mode] = rank

    def get_local_rank(self, parallel_mode: ParallelMode):
        """Returns the local rank of the current device.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: The local rank of the current device for `parallel_mode`
        :rtype: int
        """
        self._check_parallel_mode(parallel_mode)
        return self._local_ranks[parallel_mode]

    def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
        """Adds the local rank of the current device for `parallel_mode` to the context.

        :param parallel_mode: The parallel mode for the rank
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param rank: The rank to be added
        :type rank: int
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        """
        self._check_parallel_mode(parallel_mode)
        self._local_ranks[parallel_mode] = rank

    def get_next_global_rank(self, parallel_mode: ParallelMode):
        """Returns the global rank of the next device.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: The global rank of the next device for `parallel_mode`
        :rtype: int
        """
        self._check_parallel_mode(parallel_mode)

        # get rank and world size
        local_rank = self.get_local_rank(parallel_mode)
        world_size = self.get_world_size(parallel_mode)
        ranks_in_group = self.get_ranks_in_group(parallel_mode)

        return ranks_in_group[(local_rank + 1) % world_size]

    def get_prev_global_rank(self, parallel_mode: ParallelMode):
        """Returns the global rank of the previous device.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: The global rank of the previous device for `parallel_mode`
        :rtype: int
        """
        self._check_parallel_mode(parallel_mode)

        # get rank and world size
        local_rank = self.get_local_rank(parallel_mode)
        world_size = self.get_world_size(parallel_mode)
        ranks_in_group = self.get_ranks_in_group(parallel_mode)

        return ranks_in_group[(local_rank - 1) % world_size]
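
    # Worked example (hypothetical values): for a pipeline group whose global
    # ranks are [1, 5, 9, 13], the device with local rank 3 (global rank 13)
    # wraps around the ring, so get_next_global_rank(ParallelMode.PIPELINE)
    # returns ranks_in_group[(3 + 1) % 4] == 1 and get_prev_global_rank
    # returns ranks_in_group[(3 - 1) % 4] == 9.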

    def is_first_rank(self, parallel_mode: ParallelMode):
        """Returns a boolean value indicating whether the current device is the first one
        among its group for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: A boolean value indicating whether the current device is the first one
            among its group for `parallel_mode`
        :rtype: bool
        """
        rank = self.get_local_rank(parallel_mode)
        return rank == 0

    def is_last_rank(self, parallel_mode: ParallelMode):
        """Returns a boolean value indicating whether the current device is the last one
        among its group for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: A boolean value indicating whether the current device is the last one
            among its group for `parallel_mode`
        :rtype: bool
        """
        rank = self.get_local_rank(parallel_mode)
        world_size = self.get_world_size(parallel_mode)
        return rank == world_size - 1

    def get_world_size(self, parallel_mode: ParallelMode):
        """Returns the world size for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: The world size for `parallel_mode`
        :rtype: int
        """
        self._check_parallel_mode(parallel_mode)
        return self._world_sizes[parallel_mode]

    def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
        """Adds world size for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param world_size: The world size to be added
        :type world_size: int
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        """
        self._check_parallel_mode(parallel_mode)
        self._world_sizes[parallel_mode] = world_size

    def get_group(self, parallel_mode: ParallelMode):
        """Returns the group of the current device for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: The group of the current device for `parallel_mode`
        :rtype: torch.distributed.ProcessGroup
        """
        self._check_parallel_mode(parallel_mode)
        return self._groups[parallel_mode]

    def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
        """Adds the group of the current device for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param group: The group to be added
        :type group: torch.distributed.ProcessGroup
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        """
        self._check_parallel_mode(parallel_mode)
        self._groups[parallel_mode] = group

    def get_ranks_in_group(self, parallel_mode: ParallelMode):
        """Returns the global ranks of the group that the current device belongs to
        for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        :return: The global ranks of the group for `parallel_mode`
        :rtype: list
        """
        self._check_parallel_mode(parallel_mode)
        return self._ranks_in_group[parallel_mode]

    def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
        """Adds the global ranks of the group that the current device belongs to
        for `parallel_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param ranks: List of ranks to be added
        :type ranks: list
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
            of :class:`colossalai.context.ParallelMode`
        """
        self._check_parallel_mode(parallel_mode)
        self._ranks_in_group[parallel_mode] = ranks

    def init_global_dist(self, addr=None, port=None):
        """Initializes the global distributed environment.

        :param addr: The IP address of the current device
        :type addr: str, optional
        :param port: The port to be used in the system of the current device
        :type port: int, optional
        """
        # get config
        rank = self._dist_args.local_rank
        world_size = self._dist_args.world_size
        # default env config, overridden by exporting
        # these variables in your bash script
        addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
        port = os.getenv('MASTER_PORT', '8008') if port is None else port
        init_method = f'tcp://{addr}:{port}'

        dist.init_process_group(backend=self._dist_args.backend,
                                rank=rank,
                                world_size=world_size,
                                init_method=init_method)

        # None will give the default global process group for pytorch dist operations
        self._register_dist(rank, world_size, None,
                            list(range(world_size)), ParallelMode.GLOBAL)
        self._global_ranks[ParallelMode.GLOBAL] = rank
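
    # Usage sketch (hypothetical launch): with one process per GPU, each process
    # would typically export MASTER_ADDR/MASTER_PORT (or pass addr/port here)
    # before calling init_global_dist, e.g.
    #
    #   # args is assumed to carry local_rank, world_size and backend ('nccl')
    #   ctx = ParallelContext(args)
    #   ctx.init_global_dist(addr='10.0.0.1', port=29500)
    #
    # after which ParallelMode.GLOBAL is registered with the default process group.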

    def _register_dist(self, local_rank, world_size,
                       process_group, ranks_in_group, mode):
        self.add_local_rank(mode, local_rank)
        self.add_world_size(mode, world_size)
        self.add_group(mode, process_group)
        self.add_ranks_in_group(mode, ranks_in_group)

    def check_sanity(self):
        """Checks sanity of the parallel context.

        :raises AssertionError: Raises an AssertionError if the world size does not equal the product
            of data parallel size, pipeline parallel size and tensor parallel size
        """
        dps = self.data_parallel_size
        pps = self.pipeline_parallel_size
        tps = self.tensor_parallel_size
        ws = self.world_size
        assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
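
    # Worked example (hypothetical sizes): on 8 GPUs with pipeline_parallel_size=2
    # and tensor_parallel_size=2, data_parallel_size is derived as
    # 8 // (2 * 2) == 2, and check_sanity passes since 2 * 2 * 2 == 8.
    # With pipeline_parallel_size=3 on 8 GPUs no product can match, so the
    # assertion fires.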

    def init_parallel_groups(self):
        """Initializes the parallel groups.

        :raises AssertionError: Raises an AssertionError if the field parallel is not present in the config file
        """

        # get rank and world size
        rank = self.get_global_rank()
        world_size = self.get_world_size(ParallelMode.GLOBAL)
        self.world_size = world_size

        assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file'

        # set parallel size as attributes for global context
        parallel_config = self.config.parallel
        set_parallel_size(self, parallel_config, 'pipeline',
                          'pipeline_parallel_size')
        set_parallel_size(self, parallel_config, 'tensor',
                          'tensor_parallel_size')

        # the user should not set the data parallel size manually
        # instead, it is calculated from the other parallel sizes
        self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)

        # get the tensor parallel mode and check
        tensor_parallel_mode = parallel_config['tensor'].get('mode', None)
        assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
        self.check_sanity()

        pg_init = []
        # LSG: init data parallel process group for compatibility with other parallel modules such as zero
        pg_init.append(dict(type=INITIALIZER_MAPPING['data']))

        if self.pipeline_parallel_size > 1:
            pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
        pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))

        # init specific tensor parallel group
        if tensor_parallel_mode is not None:
            tensor_parallel_cfg = parallel_config['tensor'].copy()

            # remove duplicate parameters
            tensor_parallel_cfg.pop('mode')
            tensor_parallel_cfg.pop('size')

            # add this config to initialize later
            pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))

        # run initialization of different process groups
        for initializer_cfg in pg_init:
            cfg = initializer_cfg.copy()
            initializer_type = cfg.pop('type')
            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
                rank, world_size, self.config,
                self.data_parallel_size,
                self.pipeline_parallel_size,
                self.tensor_parallel_size,
                **cfg)
            parallel_setting = initializer.init_dist_group()
            if isinstance(parallel_setting, list):
                for args in parallel_setting:
                    self._register_dist(*args)
            else:
                self._register_dist(*parallel_setting)
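
    # Config sketch (hypothetical values, assuming '2d' is in ALLOWED_MODES):
    # a config such as
    #
    #   parallel = dict(
    #       pipeline=dict(size=2),
    #       tensor=dict(size=4, mode='2d'),
    #   )
    #
    # on 16 devices yields data_parallel_size == 16 // (2 * 4) == 2, and pg_init
    # collects the data, pipeline, tensor and '2d'-specific group initializers.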

    def is_initialized(self, parallel_mode: ParallelMode):
        """Returns a boolean value indicating whether `parallel_mode` is initialized
        in the current system.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :return: A boolean value indicating whether `parallel_mode` is initialized
            in the current system
        :rtype: bool
        """
        return parallel_mode in self._groups

    def destroy(self):
        """Destroys the current distributed parallel environment.
        """
        for mode, group in self._groups.items():
            if mode is not ParallelMode.GLOBAL:
                dist.destroy_process_group(group)
        # destroy global process group
        dist.destroy_process_group()

    def set_device(self):
        """Sets distributed processes to be bound to devices.
        """
        devices_per_node = torch.cuda.device_count()
        global_rank = self.get_global_rank()
        device = global_rank % devices_per_node
        torch.cuda.set_device(device)
        print(f'process rank {global_rank} is bound to device {device}')
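
    # Worked example (hypothetical cluster): with 8 GPUs per node, global rank 11
    # maps to device 11 % 8 == 3 on its node, so ranks 8-15 occupy devices 0-7
    # of the second node.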

    def set_seed(self):
        """Sets seeds for all random libraries.
        """
        if hasattr(self.config, 'seed'):
            seed = getattr(self.config, 'seed')
        else:
            seed = 2  # default seed

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        global_rank = self.get_global_rank()

        if torch.cuda.is_available():
            # create random seed for different parallel modes
            # data parallel seeds are kept the same
            parallel_seed = seed
            add_seed(ParallelMode.DATA, parallel_seed)

            # model parallel seeds are different across ranks
            pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)

            # add seed for data parallel and tensor parallel only
            if self.is_initialized(ParallelMode.TENSOR):
                tp_rank = self.get_local_rank(ParallelMode.TENSOR)
                # the multiplier 1024 is only to increase the difference in seeds between pipeline stages
                tp_rank_with_offset = tp_rank + pipeline_offset * 1024
                tp_seed = seed + tp_rank_with_offset
                add_seed(ParallelMode.TENSOR, tp_seed)

            set_mode(ParallelMode.DATA)
            seeds = get_seeds()
            seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])

            print(f"initialized seed on rank {global_rank}, "
                  f"numpy: {seed}, python random: {seed}, {seed_str}, "
                  f"the default parallel seed is {ParallelMode.DATA}.", flush=True)
        else:
            print(f"initialized seed on rank {global_rank}, "
                  f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True)
            print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
                  flush=True)
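
    # Worked example (hypothetical ranks): with seed=2, a device at pipeline
    # local rank 1 and tensor local rank 3 gets
    # tp_seed = 2 + 3 + 1 * 1024 == 1029, while the data parallel seed is 2
    # everywhere, so the tensor parallel RNG state differs across tensor and
    # pipeline ranks while data parallel replicas stay aligned.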