fix format parallel_context.py (#359)

Co-authored-by: huangziyu <202476410arsmart@gmail.com>
ziyu huang 2022-03-10 09:29:32 +08:00 committed by Frank Lee
parent c695369af0
commit a77d73f22b

parallel_context.py

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import os
 import random
 from typing import Union
@@ -20,7 +19,7 @@ from .random import add_seed, get_seeds, set_mode
 class ParallelContext:
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.
     """
@@ -218,7 +217,8 @@ class ParallelContext:
     def is_pipeline_last_stage(self, ignore_virtual=False):
         if not ignore_virtual:
-            if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
+            if self.virtual_pipeline_parallel_size \
+                    is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
                 return False
         return self.is_last_rank(ParallelMode.PIPELINE)
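
For readers outside the codebase: under an interleaved (virtual) pipeline schedule, the rank that owns the physical last stage only behaves as the true last stage while it is executing its final virtual chunk, which is what the re-wrapped condition above checks. A minimal standalone sketch of that rule (the class below is hypothetical; only the attribute names mirror the diff):

from typing import Optional

class PipelineStageSketch:
    # hypothetical stand-in; mirrors the attributes referenced in the diff
    def __init__(self, rank: int, size: int,
                 virtual_pipeline_parallel_rank: Optional[int] = None,
                 virtual_pipeline_parallel_size: Optional[int] = None):
        self.rank = rank
        self.size = size
        self.virtual_pipeline_parallel_rank = virtual_pipeline_parallel_rank
        self.virtual_pipeline_parallel_size = virtual_pipeline_parallel_size

    def is_pipeline_last_stage(self, ignore_virtual: bool = False) -> bool:
        if not ignore_virtual:
            # with virtual stages, only the last chunk on the last rank counts
            if self.virtual_pipeline_parallel_size is not None \
                    and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
                return False
        return self.rank == self.size - 1

print(PipelineStageSketch(3, 4, 0, 2).is_pipeline_last_stage())  # False: still on chunk 0 of 2
print(PipelineStageSketch(3, 4, 1, 2).is_pipeline_last_stage())  # True: last rank, last chunk
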
@@ -300,13 +300,7 @@ class ParallelContext:
         self._check_parallel_mode(parallel_mode)
         self._ranks_in_group[parallel_mode] = ranks

-    def init_global_dist(self,
-                         rank: int,
-                         world_size: int,
-                         backend: str,
-                         host: str,
-                         port: int
-                         ):
+    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
         """Initializes the global distributed environment
         :param rank: rank for the default process group
         :type rank: int
@@ -321,18 +315,13 @@ class ParallelContext:
         """
         # initialize the default process group
         init_method = f'tcp://{host}:{port}'
-        dist.init_process_group(rank=rank,
-                                world_size=world_size,
-                                backend=backend,
-                                init_method=init_method)
+        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None,
-                            list(range(world_size)), ParallelMode.GLOBAL)
+        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)

-    def _register_dist(self, local_rank, world_size,
-                       process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
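
The collapsed calls above keep the same semantics: one torch.distributed default process group bootstrapped over a TCP rendezvous, then registered under ParallelMode.GLOBAL. A rough, self-contained sketch of that bootstrap pattern, assuming a torchrun-style launcher exports RANK and WORLD_SIZE (the helper name and the host/port defaults are illustrative, not the repository's launcher):

import os
import torch.distributed as dist

def init_global_dist(rank: int, world_size: int, backend: str, host: str, port: int) -> None:
    # same call shape as the reformatted code: a single TCP rendezvous for the default group
    init_method = f'tcp://{host}:{port}'
    dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

if __name__ == '__main__':
    # torchrun and similar launchers export RANK and WORLD_SIZE for every process
    rank = int(os.environ.get('RANK', '0'))
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    init_global_dist(rank, world_size, backend='gloo', host='127.0.0.1', port=29500)
    dist.barrier()                 # group=None, i.e. the default global process group
    dist.destroy_process_group()
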
@@ -349,7 +338,9 @@ class ParallelContext:
         tps = self.tensor_parallel_size
         ws = self.world_size
         assert ws == dps * pps * \
-            tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+            tps, f"Expected the world size {ws} to be equal to data" \
+            f" parallel size ({dps}) * pipeline parallel size " \
+            f"({pps}) * tensor parallel size ({tps})"

     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
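
The assertion being re-wrapped is the basic size identity of 3D parallelism: the world size must factor exactly into data, pipeline and tensor parallel sizes. A toy standalone version of the same invariant (hypothetical function name):

def check_world_size(world_size: int, dps: int, pps: int, tps: int) -> None:
    # hypothetical standalone version of the invariant asserted above
    assert world_size == dps * pps * tps, \
        f"Expected the world size {world_size} to be equal to data" \
        f" parallel size ({dps}) * pipeline parallel size " \
        f"({pps}) * tensor parallel size ({tps})"

check_world_size(32, dps=4, pps=2, tps=4)    # 4 * 2 * 4 == 32, passes
# check_world_size(32, dps=4, pps=2, tps=8)  # would raise: 4 * 2 * 8 == 64 != 32
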
@@ -360,8 +351,7 @@ class ParallelContext:
                 setattr(self, attr_name, ele['size'])
             else:
                 raise NotImplementedError(
-                    f"Parallel configuration does not support this kind of argument, please use int or dict"
-                )
+                    f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')

     def init_parallel_groups(self):
         """Initializes the parallel groups.
@@ -386,11 +376,13 @@ class ParallelContext:

         # get the tensor parallel mode and check
         tensor_parallel_mode = None
-        if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
+        if parallel_config is not None and 'tensor' in \
+                parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
-            assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+            assert tensor_parallel_mode in ALLOWED_MODES, \
+                f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
         env.mode = tensor_parallel_mode

         self.check_sanity()

         pg_init = []
@@ -426,12 +418,10 @@ class ParallelContext:
         for initializer_cfg in pg_init:
             cfg = initializer_cfg.copy()
             initializer_type = cfg.pop('type')
-            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
-                rank, world_size, self.config,
-                self.data_parallel_size,
-                self.pipeline_parallel_size,
-                self.tensor_parallel_size,
-                **cfg)
+            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
+                                                                              self.data_parallel_size,
+                                                                              self.pipeline_parallel_size,
+                                                                              self.tensor_parallel_size, **cfg)
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
                 for args in parallel_setting:
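
DIST_GROUP_INITIALIZER.get_module(initializer_type) is a registry lookup: each config entry names an initializer class via its 'type' key, which is popped off before the remaining fields are forwarded to the class, and init_dist_group() returns the settings to register (a single set or a list of them). A generic sketch of that registry dispatch (all names and internals below are hypothetical, not the repository's implementation):

class Registry:
    # hypothetical minimal registry mirroring the get_module / pop('type') dispatch above
    def __init__(self):
        self._modules = {}

    def register_module(self, cls):
        self._modules[cls.__name__] = cls
        return cls

    def get_module(self, name):
        return self._modules[name]

DIST_GROUP_INITIALIZER = Registry()

@DIST_GROUP_INITIALIZER.register_module
class Initializer_Data:
    def __init__(self, rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size):
        self.rank = rank
        self.data_parallel_size = data_parallel_size

    def init_dist_group(self):
        # a real initializer would create torch.distributed groups; this just returns toy settings
        return dict(local_rank=self.rank % self.data_parallel_size, mode='data_parallel')

cfg = dict(type='Initializer_Data')
initializer_type = cfg.pop('type')
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(0, 8, {}, 2, 2, 2, **cfg)
print(initializer.init_dist_group())
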
@@ -509,10 +499,9 @@ class ParallelContext:

             seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
             if self._verbose:
-                self._logger.info(
-                    f"initialized seed on rank {global_rank}, "
-                    f"numpy: {seed}, python random: {seed}, {seed_str},"
-                    f"the default parallel seed is {ParallelMode.DATA}.")
+                self._logger.info(f"initialized seed on rank {global_rank}, "
+                                  f"numpy: {seed}, python random: {seed}, {seed_str},"
+                                  f"the default parallel seed is {ParallelMode.DATA}.")
         else:
             if self._verbose:
                 self._logger.info(