mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[utils] support detection of number of processes on current node (#723)
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import random
|
||||
import socket
|
||||
from collections import Counter
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
@@ -45,6 +47,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||
self.data_parallel_size = 1
|
||||
self.pipeline_parallel_size = 1
|
||||
self.tensor_parallel_size = 1
|
||||
self.num_processes_on_current_node = -1
|
||||
self.virtual_pipeline_parallel_size = None
|
||||
self.virtual_pipeline_parallel_rank = None
|
||||
|
||||
@@ -81,6 +84,13 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||
else:
|
||||
raise TypeError("Invalid type for config, only dictionary or string is supported")
|
||||
|
||||
def detect_num_processes_on_current_node(self):
|
||||
hostname = socket.gethostname()
|
||||
hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
|
||||
dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
|
||||
counter = Counter(hostname_list)
|
||||
self.num_processes_on_current_node = counter[hostname]
|
||||
|
||||
@staticmethod
|
||||
def _check_parallel_mode(parallel_mode: ParallelMode):
|
||||
assert isinstance(parallel_mode, ParallelMode)
|
||||
|
Reference in New Issue
Block a user