[utils] support detection of number of processes on current node (#723)

This commit is contained in:
Frank Lee
2022-04-12 09:28:19 +08:00
committed by GitHub
parent 4d90a7b513
commit 04ff5ea546
2 changed files with 19 additions and 4 deletions

View File

@@ -2,6 +2,8 @@
# -*- encoding: utf-8 -*-
import random
import socket
from collections import Counter
from typing import Union
import numpy as np
@@ -45,6 +47,7 @@ class ParallelContext(metaclass=SingletonMeta):
self.data_parallel_size = 1
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
@@ -81,6 +84,13 @@ class ParallelContext(metaclass=SingletonMeta):
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")
def detect_num_processes_on_current_node(self):
    """Count how many processes of the global group run on this host.

    The local hostname is exchanged with every rank of the
    ``ParallelMode.GLOBAL`` process group, and the number of gathered
    entries equal to it is stored in
    ``self.num_processes_on_current_node``. Nothing is returned.

    Note: ``dist.all_gather_object`` is a collective operation, so every
    rank of the global group is expected to reach this call.
    """
    local_host = socket.gethostname()
    global_size = self.get_world_size(ParallelMode.GLOBAL)
    global_group = self.get_group(ParallelMode.GLOBAL)

    # One placeholder slot per rank; all_gather_object fills them in place.
    gathered_hosts = [None] * global_size
    dist.all_gather_object(gathered_hosts, local_host, group=global_group)

    # Ranks reporting the same hostname share this node.
    self.num_processes_on_current_node = gathered_hosts.count(local_host)
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode)