Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[utils] support detection of number of processes on current node (#723)
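The new detection runs inside launch(), so the usual entry points pick it up automatically at startup. Below is a minimal, hedged usage sketch of calling launch() directly; aside from `config`, which appears in the hunk below, the keyword names (`rank`, `world_size`, `host`, `port`, `backend`, `seed`, `verbose`) are assumed from ColossalAI's documented launcher API rather than anything this commit changes.

    # Hedged usage sketch: direct colossalai.launch call for a single-node, 2-process job.
    # Keyword names other than `config` are assumptions based on the documented launcher API.
    import colossalai

    colossalai.launch(
        config=dict(),        # an empty config is enough to initialize the global context
        rank=0,               # global rank of this process
        world_size=2,         # total number of processes in the job
        host='127.0.0.1',     # address of the rank-0 process
        port=29500,           # free TCP port for rendezvous
        backend='nccl',       # torch.distributed backend
        seed=1024,
        verbose=True,
    )
    # After launch(), gpc has set the device, seeded RNGs, and (with this commit)
    # recorded how many processes run on the current node.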
@@ -102,6 +102,9 @@ def launch(config: Union[str, Path, Config, Dict],
     # if local rank is not given, calculate automatically
     gpc.set_device(local_rank)
 
+    # set the number of processes running on the same node
+    gpc.detect_num_processes_on_current_node()
+
     gpc.set_seed(seed)
 
     if verbose:
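For readers wondering what the added call does, here is a minimal sketch of one way node-local process counting can be implemented on top of torch.distributed: gather every rank's hostname and count how many match the local one. This is an illustrative assumption about the technique, not the actual body of ParallelContext.detect_num_processes_on_current_node().

    # Hedged sketch: count how many ranks run on this node by gathering hostnames.
    # Assumes torch.distributed has already been initialized (e.g. by colossalai.launch).
    import socket

    import torch.distributed as dist


    def detect_num_processes_on_current_node() -> int:
        """Return the number of ranks whose hostname matches this process's hostname."""
        hostname = socket.gethostname()
        # Collect every rank's hostname; all_gather_object handles picklable Python objects.
        hostnames = [None] * dist.get_world_size()
        dist.all_gather_object(hostnames, hostname)
        return hostnames.count(hostname)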
@@ -398,15 +401,17 @@ def initialize(model: nn.Module,
         else:
             scatter_gather = False
         if use_interleaved:
-            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
+            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+                                                   gpc.config.model.num_chunks,
+                                                   tensor_shape=tensor_shape,
+                                                   scatter_gather_tensors=scatter_gather)
         else:
             schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
+                                        tensor_shape=tensor_shape,
+                                        scatter_gather_tensors=scatter_gather)
     else:
         schedule = NonPipelineSchedule()
 
 
     if gradient_handler_cfg is None:
         gradient_handlers = None
         if verbose and not isinstance(model, DDP):
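The reformatted calls above read two config fields: gpc.config.NUM_MICRO_BATCHES and gpc.config.model.num_chunks. A hedged sketch of a config that would exercise the interleaved branch follows; the `parallel` section and the rule that the presence of model.num_chunks drives `use_interleaved` are illustrative assumptions, only the two field names come from the diff itself.

    # Hedged config sketch for the schedule selection above.
    # Only NUM_MICRO_BATCHES and model.num_chunks appear in the diff; the rest is assumed.
    pipeline_config = dict(
        parallel=dict(pipeline=2),   # assumed: pipeline parallel size > 1 so a pipeline schedule is built
        NUM_MICRO_BATCHES=4,         # first argument to both pipeline schedules
        model=dict(num_chunks=2),    # passed to InterleavedPipelineSchedule as the number of chunks per stage
    )
    # Without model.num_chunks the same code path falls back to PipelineSchedule,
    # and with pipeline parallelism disabled it builds NonPipelineSchedule().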