fix format parallel_context.py (#359)

Co-authored-by: huangziyu <202476410arsmart@gmail.com>
ziyu huang 2022-03-10 09:29:32 +08:00 committed by Frank Lee
parent c695369af0
commit a77d73f22b

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-import os
 import random
 from typing import Union
@ -218,7 +217,8 @@ class ParallelContext:
def is_pipeline_last_stage(self, ignore_virtual=False): def is_pipeline_last_stage(self, ignore_virtual=False):
if not ignore_virtual: if not ignore_virtual:
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: if self.virtual_pipeline_parallel_size \
is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
return False return False
return self.is_last_rank(ParallelMode.PIPELINE) return self.is_last_rank(ParallelMode.PIPELINE)
@@ -300,13 +300,7 @@ class ParallelContext:
         self._check_parallel_mode(parallel_mode)
         self._ranks_in_group[parallel_mode] = ranks
 
-    def init_global_dist(self,
-                         rank: int,
-                         world_size: int,
-                         backend: str,
-                         host: str,
-                         port: int
-                         ):
+    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
         """Initializes the global distributed environment
 
         :param rank: rank for the default process group
         :type rank: int
@@ -321,18 +315,13 @@ class ParallelContext:
         """
         # initialize the default process group
         init_method = f'tcp://{host}:{port}'
-        dist.init_process_group(rank=rank,
-                                world_size=world_size,
-                                backend=backend,
-                                init_method=init_method)
+        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None,
-                            list(range(world_size)), ParallelMode.GLOBAL)
+        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)
 
-    def _register_dist(self, local_rank, world_size,
-                       process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
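
For context on the init_global_dist and dist.init_process_group lines joined above: they build the default PyTorch process group from a rank, world size, backend, and a TCP rendezvous address. A minimal standalone sketch of the same call pattern, with placeholder host/port, backend, and size values that are not taken from this commit:

    import torch.distributed as dist

    # Mirrors init_method = f'tcp://{host}:{port}' in the code above;
    # the address, backend, rank and world size below are illustrative placeholders.
    init_method = 'tcp://127.0.0.1:29500'
    dist.init_process_group(rank=0, world_size=1, backend='gloo', init_method=init_method)
    dist.destroy_process_group()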
@@ -349,7 +338,9 @@ class ParallelContext:
         tps = self.tensor_parallel_size
         ws = self.world_size
         assert ws == dps * pps * \
-            tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+            tps, f"Expected the world size {ws} to be equal to data" \
+            f" parallel size ({dps}) * pipeline parallel size " \
+            f"({pps}) * tensor parallel size ({tps})"
 
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
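
The assertion reformatted above encodes the invariant world_size == data_parallel_size * pipeline_parallel_size * tensor_parallel_size. A quick illustrative check with hypothetical sizes (not values from this commit):

    dps, pps, tps = 2, 2, 2    # hypothetical data / pipeline / tensor parallel sizes
    ws = dps * pps * tps       # a valid world size for this split: 8 ranks
    assert ws == dps * pps * tps, \
        f"Expected the world size {ws} to be equal to data parallel size ({dps}) * " \
        f"pipeline parallel size ({pps}) * tensor parallel size ({tps})"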
@@ -360,8 +351,7 @@ class ParallelContext:
                 setattr(self, attr_name, ele['size'])
             else:
                 raise NotImplementedError(
-                    f"Parallel configuration does not support this kind of argument, please use int or dict"
-                )
+                    f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')
 
     def init_parallel_groups(self):
         """Initializes the parallel groups.
@ -386,9 +376,11 @@ class ParallelContext:
# get the tensor parallel mode and check # get the tensor parallel mode and check
tensor_parallel_mode = None tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']: if parallel_config is not None and 'tensor' in \
parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode'] tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}" assert tensor_parallel_mode in ALLOWED_MODES, \
f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
env.mode = tensor_parallel_mode env.mode = tensor_parallel_mode
self.check_sanity() self.check_sanity()
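
The wrapped condition above pulls the tensor-parallel mode out of the user's parallel configuration. A hedged example of a config dict that this check would accept, where the sizes and the '2d' mode string are illustrative and ALLOWED_MODES is defined elsewhere in the codebase:

    # Illustrative parallel config: 'mode' lives under the 'tensor' entry,
    # which is why the code checks parallel_config['tensor']['mode'].
    parallel_config = dict(
        pipeline=dict(size=2),
        tensor=dict(size=4, mode='2d'),    # mode must be a member of ALLOWED_MODES
    )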
@@ -426,12 +418,10 @@ class ParallelContext:
         for initializer_cfg in pg_init:
             cfg = initializer_cfg.copy()
             initializer_type = cfg.pop('type')
-            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
-                rank, world_size, self.config,
-                self.data_parallel_size,
-                self.pipeline_parallel_size,
-                self.tensor_parallel_size,
-                **cfg)
+            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
+                                                                              self.data_parallel_size,
+                                                                              self.pipeline_parallel_size,
+                                                                              self.tensor_parallel_size, **cfg)
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
                 for args in parallel_setting:
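
The loop above instantiates each process-group initializer through a registry lookup and then calls init_dist_group() on it. A minimal sketch of the dict-backed registry pattern this relies on; the class and variable names are made up for illustration and are not Colossal-AI's actual DIST_GROUP_INITIALIZER implementation:

    class SimpleRegistry:
        """Toy registry mapping a type name to a class."""

        def __init__(self):
            self._modules = {}

        def register_module(self, cls):
            self._modules[cls.__name__] = cls
            return cls

        def get_module(self, name):
            return self._modules[name]

    REGISTRY = SimpleRegistry()

    @REGISTRY.register_module
    class DummyInitializer:
        def __init__(self, rank, world_size, config, dps, pps, tps, **cfg):
            self.args = (rank, world_size, config, dps, pps, tps, cfg)

        def init_dist_group(self):
            # A real initializer would create a torch.distributed group here.
            return self.args

    initializer = REGISTRY.get_module('DummyInitializer')(0, 8, {}, 2, 2, 2)
    parallel_setting = initializer.init_dist_group()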
@@ -509,8 +499,7 @@ class ParallelContext:
             seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
 
             if self._verbose:
-                self._logger.info(
-                    f"initialized seed on rank {global_rank}, "
-                    f"numpy: {seed}, python random: {seed}, {seed_str},"
-                    f"the default parallel seed is {ParallelMode.DATA}.")
+                self._logger.info(f"initialized seed on rank {global_rank}, "
+                                  f"numpy: {seed}, python random: {seed}, {seed_str},"
+                                  f"the default parallel seed is {ParallelMode.DATA}.")
             else: