Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-22 13:41:43 +00:00
fix format parallel_context.py (#359)
Co-authored-by: huangziyu <202476410arsmart@gmail.com>
commit a77d73f22b
parent c695369af0
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import os
 import random
 from typing import Union
 
@@ -20,7 +19,7 @@ from .random import add_seed, get_seeds, set_mode
 
 
 class ParallelContext:
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.
 
     """
@@ -218,7 +217,8 @@ class ParallelContext:
 
     def is_pipeline_last_stage(self, ignore_virtual=False):
         if not ignore_virtual:
-            if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
+            if self.virtual_pipeline_parallel_size \
+                    is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
                 return False
         return self.is_last_rank(ParallelMode.PIPELINE)
 
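The split condition above changes formatting only: a rank is not the last pipeline stage when a virtual (interleaved) pipeline is configured and this rank is not handling the final chunk. A minimal standalone sketch of that check in plain Python, with illustrative names and values that are not part of the commit:

def is_last_virtual_stage(virtual_pp_size, virtual_pp_rank):
    # Mirrors the check in is_pipeline_last_stage: not the last stage if a
    # virtual pipeline exists and this chunk is not the final one.
    if virtual_pp_size is not None and virtual_pp_rank != virtual_pp_size - 1:
        return False
    return True

print(is_last_virtual_stage(None, 0))  # True: no virtual pipeline configured
print(is_last_virtual_stage(4, 2))     # False: chunk 2 of 4 is not the last
print(is_last_virtual_stage(4, 3))     # True: final chunk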
@@ -300,13 +300,7 @@ class ParallelContext:
         self._check_parallel_mode(parallel_mode)
         self._ranks_in_group[parallel_mode] = ranks
 
-    def init_global_dist(self,
-                         rank: int,
-                         world_size: int,
-                         backend: str,
-                         host: str,
-                         port: int
-                         ):
+    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
         """Initializes the global distributed environment
         :param rank: rank for the default process group
         :type rank: int
@@ -321,18 +315,13 @@ class ParallelContext:
         """
         # initialize the default process group
         init_method = f'tcp://{host}:{port}'
-        dist.init_process_group(rank=rank,
-                                world_size=world_size,
-                                backend=backend,
-                                init_method=init_method)
+        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None,
-                            list(range(world_size)), ParallelMode.GLOBAL)
+        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)
 
-    def _register_dist(self, local_rank, world_size,
-                       process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
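The joined dist.init_process_group(...) call above follows the standard PyTorch TCP-rendezvous pattern. A self-contained sketch of the same pattern, assuming a single-process world with the gloo backend and a loopback host/port chosen for illustration (none of these values come from the commit):

import torch.distributed as dist

# Illustrative values; ColossalAI passes these in through init_global_dist.
rank, world_size, backend = 0, 1, 'gloo'
host, port = '127.0.0.1', 29500

init_method = f'tcp://{host}:{port}'
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

# Collectives issued with group=None target this default (global) group,
# which is why _register_dist stores None for ParallelMode.GLOBAL.
dist.barrier()
dist.destroy_process_group()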
@@ -349,7 +338,9 @@ class ParallelContext:
         tps = self.tensor_parallel_size
         ws = self.world_size
         assert ws == dps * pps * \
-            tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+            tps, f"Expected the world size {ws} to be equal to data" \
+            f" parallel size ({dps}) * pipeline parallel size " \
+            f"({pps}) * tensor parallel size ({tps})"
 
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
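The wrapped assertion above encodes the invariant world_size == data parallel size * pipeline parallel size * tensor parallel size. A quick numeric check with assumed sizes (2 * 2 * 2 = 8 ranks), reusing the reformatted message:

# Assumed example sizes, not taken from the commit.
dps, pps, tps = 2, 2, 2
ws = dps * pps * tps  # a valid 8-rank layout; set ws = 6 to see the assertion fire
assert ws == dps * pps * \
    tps, f"Expected the world size {ws} to be equal to data" \
    f" parallel size ({dps}) * pipeline parallel size " \
    f"({pps}) * tensor parallel size ({tps})"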
@@ -360,8 +351,7 @@ class ParallelContext:
                 setattr(self, attr_name, ele['size'])
             else:
                 raise NotImplementedError(
-                    f"Parallel configuration does not support this kind of argument, please use int or dict"
-                )
+                    f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')
 
     def init_parallel_groups(self):
         """Initializes the parallel groups.
@@ -386,11 +376,13 @@ class ParallelContext:
 
         # get the tensor parallel mode and check
         tensor_parallel_mode = None
-        if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
+        if parallel_config is not None and 'tensor' in \
+                parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
-            assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+            assert tensor_parallel_mode in ALLOWED_MODES, \
+                f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
         env.mode = tensor_parallel_mode
 
         self.check_sanity()
 
         pg_init = []
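The lookups above read the tensor-parallel mode out of the user's parallel configuration. A hedged illustration of the dict shape those lookups assume; the 'tensor' and 'mode' keys come from the diff, while the sizes and the '2d' mode are example values, not part of this commit:

# Hypothetical parallel config; only the keys accessed above matter here.
parallel_config = dict(
    pipeline=dict(size=2),
    tensor=dict(size=4, mode='2d'),
)

tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in \
        parallel_config and 'mode' in parallel_config['tensor']:
    tensor_parallel_mode = parallel_config['tensor']['mode']
print(tensor_parallel_mode)  # -> '2d'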
@@ -426,12 +418,10 @@ class ParallelContext:
         for initializer_cfg in pg_init:
             cfg = initializer_cfg.copy()
             initializer_type = cfg.pop('type')
-            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
-                rank, world_size, self.config,
-                self.data_parallel_size,
-                self.pipeline_parallel_size,
-                self.tensor_parallel_size,
-                **cfg)
+            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
+                                                                              self.data_parallel_size,
+                                                                              self.pipeline_parallel_size,
+                                                                              self.tensor_parallel_size, **cfg)
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
                 for args in parallel_setting:
@@ -509,10 +499,9 @@ class ParallelContext:
             seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
 
             if self._verbose:
-                self._logger.info(
-                    f"initialized seed on rank {global_rank}, "
-                    f"numpy: {seed}, python random: {seed}, {seed_str},"
-                    f"the default parallel seed is {ParallelMode.DATA}.")
+                self._logger.info(f"initialized seed on rank {global_rank}, "
+                                  f"numpy: {seed}, python random: {seed}, {seed_str},"
+                                  f"the default parallel seed is {ParallelMode.DATA}.")
         else:
             if self._verbose:
                 self._logger.info(