mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[polish] polish singleton and global context (#500)
This commit is contained in:
@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
|
||||
from .parallel_mode import ParallelMode
|
||||
from .random import add_seed, get_seeds, set_mode
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
|
||||
|
||||
class ParallelContext:
|
||||
class ParallelContext(metaclass=SingletonMeta):
|
||||
"""This class provides interface functions for users to get the parallel context,
|
||||
such as the global rank, the local rank, the world size, etc. of each device.
|
||||
|
||||
"""
|
||||
|
||||
__instance = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance():
|
||||
if ParallelContext.__instance is None:
|
||||
ParallelContext()
|
||||
return ParallelContext.__instance
|
||||
|
||||
def __init__(self):
|
||||
# create a singleton instance
|
||||
if ParallelContext.__instance is not None:
|
||||
raise Exception(
|
||||
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
|
||||
else:
|
||||
ParallelContext.__instance = self
|
||||
|
||||
# distributed settings
|
||||
self._global_ranks = dict()
|
||||
self._local_ranks = dict()
|
||||
@@ -510,3 +496,6 @@ class ParallelContext:
|
||||
|
||||
def set_virtual_pipeline_parallel_rank(self, rank):
|
||||
self.virtual_pipeline_parallel_rank = rank
|
||||
|
||||
|
||||
global_context = ParallelContext()
|
||||
|
||||
Reference in New Issue
Block a user