[polish] polish singleton and global context (#500)

This commit is contained in:
Jiarui Fang
2022-03-23 18:03:39 +08:00
committed by GitHub
parent 9ec1ce6ab1
commit a445e118cf
18 changed files with 39 additions and 47 deletions

View File

@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
from colossalai.context.singleton_meta import SingletonMeta
class ParallelContext:
class ParallelContext(metaclass=SingletonMeta):
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
"""
__instance = None
@staticmethod
def get_instance():
if ParallelContext.__instance is None:
ParallelContext()
return ParallelContext.__instance
def __init__(self):
# create a singleton instance
if ParallelContext.__instance is not None:
raise Exception(
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
else:
ParallelContext.__instance = self
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
@@ -510,3 +496,6 @@ class ParallelContext:
def set_virtual_pipeline_parallel_rank(self, rank):
self.virtual_pipeline_parallel_rank = rank
global_context = ParallelContext()