[utils] support detection of number of processes on current node (#723)

This commit is contained in:
Frank Lee
2022-04-12 09:28:19 +08:00
committed by GitHub
parent 4d90a7b513
commit 04ff5ea546
2 changed files with 19 additions and 4 deletions

View File

@@ -102,6 +102,9 @@ def launch(config: Union[str, Path, Config, Dict],
# if local rank is not given, calculate automatically
gpc.set_device(local_rank)
# set the number of processes running on the same node
gpc.detect_num_processes_on_current_node()
gpc.set_seed(seed)
if verbose:
@@ -398,15 +401,17 @@ def initialize(model: nn.Module,
else:
scatter_gather = False
if use_interleaved:
-schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
+schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+gpc.config.model.num_chunks,
+tensor_shape=tensor_shape,
+scatter_gather_tensors=scatter_gather)
else:
schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
+tensor_shape=tensor_shape,
+scatter_gather_tensors=scatter_gather)
else:
schedule = NonPipelineSchedule()
if gradient_handler_cfg is None:
gradient_handlers = None
if verbose and not isinstance(model, DDP):