Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
moved env variables to global variables (#215)
added branch context; added vocab parallel layers; moved split_batch from load_batch to tensor parallel embedding layers; updated gpt model; updated unit test cases; fixed a few collective communicator bugs
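The heart of the change is replacing scattered os.environ reads and writes with attributes on a process-local Python object. A minimal sketch of such a holder, assuming a plain attribute container (the attribute names come from the hunks below; the actual class in colossalai/global_variables.py may be shaped differently):

# Hypothetical sketch only; not the verbatim colossalai implementation.
class TensorParallelEnv:
    def __init__(self):
        self.mode = None                 # replaces os.environ[TENSOR_PARALLEL_MODE]
        self.parallel_input_1d = False   # replaces os.environ[PARALLEL_INPUT_1D]
        self.summa_dim = None            # replaces os.environ[SUMMA_DIM]
        self.tesseract_dim = None        # replaces os.environ[TESSERACT_DIM]
        self.tesseract_dep = None        # replaces os.environ[TESSERACT_DEP]
        self.depth_3d = None             # replaces os.environ[DEPTH_3D]
        self.input_group_3d = None       # replaces os.environ[INPUT_GROUP_3D]
        self.weight_group_3d = None      # replaces os.environ[WEIGHT_GROUP_3D]
        self.output_group_3d = None      # replaces os.environ[OUTPUT_GROUP_3D]

tensor_parallel_env = TensorParallelEnv()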
@@ -8,14 +8,15 @@ from typing import Union
 import numpy as np
 import torch
 import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
+from colossalai.global_variables import moe_env
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
-from colossalai.global_variables import moe_env


 class ParallelContext:
@@ -307,7 +308,6 @@ class ParallelContext:
                          port: int
                          ):
         """Initializes the global distributed environment
-
         :param rank: rank for the default process group
         :type rank: int
         :param world_size: world size of the default process group
@@ -389,7 +389,8 @@ class ParallelContext:
         if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
             assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
-            os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
+            env.mode = tensor_parallel_mode
+
         self.check_sanity()

         pg_init = []
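One concrete difference the new pattern makes, shown with a stand-in object (SimpleNamespace here is only an illustration, not the real tensor_parallel_env): os.environ stores every value as a string, while attributes on a module-level object keep their native Python types.

import os
from types import SimpleNamespace

os.environ['SUMMA_DIM'] = str(2)          # env vars are always strings
assert os.environ['SUMMA_DIM'] == '2'     # reading back requires int()

env = SimpleNamespace(summa_dim=None)     # stand-in for tensor_parallel_env
env.summa_dim = 2
assert env.summa_dim == 2                 # the int survives as-is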
@@ -1,22 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-import os
-import torch.distributed as dist

 from colossalai.context import Config
+import torch.distributed as dist
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.registry import DIST_GROUP_INITIALIZER
+from .process_group_initializer import ProcessGroupInitializer

 from ..parallel_mode import ParallelMode
-from colossalai.constants import PARALLEL_INPUT_1D
-from .process_group_initializer import ProcessGroupInitializer


 @DIST_GROUP_INITIALIZER.register_module
 class Initializer_1D(ProcessGroupInitializer):
-    """A ProcessGroupInitializer for 1d tensor parallelism.
-
-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
-    """
+    '''A ProcessGroupInitializer for 1d tensor parallelism.
+    '''

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -24,7 +20,7 @@ class Initializer_1D(ProcessGroupInitializer):

     def init_dist_group(self):
         """Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
-
+
         :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
         :rtype: Tuple
         """
@@ -33,7 +29,7 @@ class Initializer_1D(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_1D
-        os.environ[PARALLEL_INPUT_1D] = ''
+        env.parallel_input_1d = False

         for i in range(self.num_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
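A worked example of the 1D grouping loop above, assuming 8 GPUs with a tensor parallel size of 4: each group is a consecutive block of ranks.

tensor_parallel_size = 4
world_size = 8
num_group = world_size // tensor_parallel_size

groups = [[i * tensor_parallel_size + j for j in range(tensor_parallel_size)]
          for i in range(num_group)]
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]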
@@ -1,34 +1,31 @@
 import math
-import os
-
 import torch.distributed as dist

 from colossalai.constants import SUMMA_DIM
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from colossalai.global_variables import tensor_parallel_env as env


 def _check_summa_env_var(summa_dim):
     # check environment variable for SUMMA
-    env_summa_dim = os.environ.get(SUMMA_DIM, None)
+    env_summa_dim = env.summa_dim

     if env_summa_dim:
         assert int(env_summa_dim) == summa_dim, \
             'SUMMA_DIM has been set in the current environment and ' \
             'does not match with the value passed to this initialized'
     else:
-        os.environ[SUMMA_DIM] = str(summa_dim)
+        env.summa_dim = summa_dim


 class Initializer_2D_Row(ProcessGroupInitializer):
     """2d tensor parallel initialization among rows.

     :param num_group: The number of all tensor groups
     :param summa_dim: The dimension of SUMMA
     :param args: Args used to initialize base class
     :param kwargs: Kwargs used to initialize base class

     :type num_group: int
     :type summa_dim: int
     """
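For reference, the SUMMA dimension checked above is the side length of the square 2D process mesh, i.e. the square root of the tensor parallel size. A small sketch of that relationship (the helper name is hypothetical, not part of this diff):

import math

def infer_summa_dim(tensor_parallel_size: int) -> int:
    # hypothetical helper: a 2D mesh needs a perfect-square parallel size
    summa_dim = int(math.sqrt(tensor_parallel_size))
    assert summa_dim ** 2 == tensor_parallel_size, \
        '2D tensor parallelism expects a perfect-square tensor parallel size'
    return summa_dim

print(infer_summa_dim(4))  # 2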
@@ -132,7 +129,7 @@ class Initializer_2D(ProcessGroupInitializer):

     def init_dist_group(self):
         """Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
-
+
         :return: 2D tensor parallelism's information
         :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
         """
@@ -2,22 +2,21 @@
 # -*- encoding: utf-8 -*-

 import math
-import os

 import torch.distributed as dist

 from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
 from colossalai.context import Config
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.registry import DIST_GROUP_INITIALIZER
+from .process_group_initializer import ProcessGroupInitializer

 from ..parallel_mode import ParallelMode
-from .process_group_initializer import ProcessGroupInitializer


 def _check_tesseract_env_var(tesseract_dim: int,
                              tesseract_dep: int):
-    # check environment variable for TESSERACT
-    env_tesseract_dim = os.environ.get(TESSERACT_DIM, None)
-    env_tesseract_dep = os.environ.get(TESSERACT_DEP, None)
+    # check global variable for TESSERACT
+    env_tesseract_dim = env.tesseract_dim
+    env_tesseract_dep = env.tesseract_dep

     if env_tesseract_dim and env_tesseract_dep:
         assert int(env_tesseract_dim) == tesseract_dim, \
@@ -27,8 +26,8 @@ def _check_tesseract_env_var(tesseract_dim: int,
             'TESSERACT_DEP has been set in the current environment and ' \
             'does not match with the value passed to this initialized'
     else:
-        os.environ[TESSERACT_DIM] = str(tesseract_dim)
-        os.environ[TESSERACT_DEP] = str(tesseract_dep)
+        env.tesseract_dim = tesseract_dim
+        env.tesseract_dep = tesseract_dep


 # i row j col k dep
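The two tesseract values guarded above factor the tensor parallel size as tesseract_dim ** 2 * tesseract_dep: a dim x dim grid replicated dep times. A hypothetical check along those lines (not part of this diff):

def infer_tesseract_dim(tensor_parallel_size: int, tesseract_dep: int) -> int:
    # hypothetical helper mirroring the guard above
    tesseract_dim = int((tensor_parallel_size / tesseract_dep) ** 0.5)
    assert tesseract_dim ** 2 * tesseract_dep == tensor_parallel_size, \
        '2.5D expects tensor parallel size == tesseract_dim^2 * tesseract_dep'
    return tesseract_dim

print(infer_tesseract_dim(8, 2))  # 2: a 2 x 2 grid, 2 layers deep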
@@ -245,7 +244,6 @@ class Initializer_2p5D(ProcessGroupInitializer):
     :param pipeline_parallel_size: Size of pipeline parallel
     :param tensor_parallel_size: Size of tensor parallel
     :param depth: The depth of 2p5d parallel
-
     :type rank: int
     :type world_size: int
     :type config: Config
@@ -281,7 +279,7 @@ class Initializer_2p5D(ProcessGroupInitializer):

     def init_dist_group(self):
         """Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
-
+
         :return: Whole 2p5D tensor parallelism's information
         :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
         """
@@ -2,10 +2,9 @@
 # -*- encoding: utf-8 -*-

 import math
-import os
-
 import torch.distributed as dist
 from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
@@ -13,15 +12,15 @@ from .process_group_initializer import ProcessGroupInitializer


 def _check_depth_env_var(depth):
-    # check environment variable for SUMMA
-    env_depth = os.environ.get(DEPTH_3D, None)
+    # check global variable
+    env_depth = env.depth_3d

     if env_depth:
         assert int(env_depth) == depth, \
             'DEPTH_3D has been set in the current environment and ' \
             'does not match with the value passed to this initialized'
     else:
-        os.environ[DEPTH_3D] = str(depth)
+        env.depth_3d = depth


 class Initializer_3D_Input(ProcessGroupInitializer):
@@ -34,6 +33,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
     :type num_group: int
     :type depth: int
     """
+
     def __init__(self, num_group: int, depth: int, *args):
         super().__init__(*args)
         self.num_group = num_group
@@ -50,15 +50,12 @@ class Initializer_3D_Input(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_INPUT
-        os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D
+        env.input_group_3d = mode

         for h in range(self.num_group):
             for i in range(self.depth):
                 for k in range(self.depth):
-                    ranks = [
-                        h * self.depth**3 + i + self.depth *
-                        (j + self.depth * k) for j in range(self.depth)
-                    ]
+                    ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
                     group = dist.new_group(ranks)

                     if self.rank in ranks:
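To make the rank arithmetic above concrete: with depth = 2 and a single group (h = 0), each input group varies j while h, i, k stay fixed, and every one of the 8 ranks lands in exactly one group.

depth = 2
for i in range(depth):
    for k in range(depth):
        ranks = [0 * depth**3 + i + depth * (j + depth * k) for j in range(depth)]
        print((i, k), ranks)
# (0, 0) [0, 2]
# (0, 1) [4, 6]
# (1, 0) [1, 3]
# (1, 1) [5, 7]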
@@ -97,15 +94,12 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_WEIGHT
-        os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D
+        env.weight_group_3d = mode

         for h in range(self.num_group):
             for k in range(self.depth):
                 for j in range(self.depth):
-                    ranks = [
-                        h * self.depth**3 + i + self.depth *
-                        (j + self.depth * k) for i in range(self.depth)
-                    ]
+                    ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
                     group = dist.new_group(ranks)

                     if self.rank in ranks:
@@ -118,7 +112,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):


 class Initializer_3D_Output(ProcessGroupInitializer):
-    """3D tensor parallel initialization among weight.
+    """3D tensor parallel initialization among output.

     :param num_group: The number of all tensor groups
     :param depth: Depth of 3D parallelism
@@ -144,15 +138,12 @@ class Initializer_3D_Output(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_OUTPUT
-        os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D
+        env.output_group_3d = mode

         for h in range(self.num_group):
             for i in range(self.depth):
                 for j in range(self.depth):
-                    ranks = [
-                        h * self.depth**3 + i + self.depth *
-                        (j + self.depth * k) for k in range(self.depth)
-                    ]
+                    ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
                     group = dist.new_group(ranks)

                     if self.rank in ranks:
@@ -170,6 +161,7 @@ class Initializer_3D(ProcessGroupInitializer):

     :param args: Args used to initialize ProcessGroupInitializer
     """
+
     def __init__(self, *args):
         super().__init__(*args)
         self.num_group = self.world_size // self.tensor_parallel_size
@@ -178,16 +170,13 @@ class Initializer_3D(ProcessGroupInitializer):
             f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})'
         _check_depth_env_var(self.depth)

-        self.input_initializer = Initializer_3D_Input(self.num_group,
-                                                      self.depth, *args)
-        self.weight_initializer = Initializer_3D_Weight(
-            self.num_group, self.depth, *args)
-        self.output_initializer = Initializer_3D_Output(
-            self.num_group, self.depth, *args)
+        self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
+        self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args)
+        self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args)

     def init_dist_group(self):
         """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
-
+
         :return: 3D tensor parallelism's information
         :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
         """
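Per the assertion message above ('if not cube root' presumably means 'is not the cube root'), the 3D depth must be the cube root of the tensor parallel size; for example, 8 GPUs form a 2 x 2 x 2 cube:

tensor_parallel_size = 8
depth = round(tensor_parallel_size ** (1 / 3))  # cube root, rounded to int
assert depth ** 3 == tensor_parallel_size
print(depth)  # 2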