[polish] polish singleton and global context (#500)

This commit is contained in:
Jiarui Fang
2022-03-23 18:03:39 +08:00
committed by GitHub
parent 9ec1ce6ab1
commit a445e118cf
18 changed files with 39 additions and 47 deletions

View File

@@ -1,6 +1,6 @@
from .config import Config, ConfigException
from .parallel_context import ParallelContext
from .moe_context import MoeContext
from .parallel_mode import ParallelMode
from .moe_context import MOE_CONTEXT
from .process_group_initializer import *
from .random import *

View File

@@ -1,6 +1,9 @@
import torch
import torch.distributed as dist
from .parallel_mode import ParallelMode
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
from typing import Tuple
@@ -56,17 +59,10 @@ class MoeParallelInfo:
self.dp_group = group
class MoeContext:
class MoeContext(metaclass=SingletonMeta):
"""MoE parallel context manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
__instance = None
@staticmethod
def get_instance():
if MoeContext.__instance is None:
MoeContext.__instance = MoeContext()
return MoeContext.__instance
def __init__(self):
self.world_size = 1
@@ -160,3 +156,6 @@ class MoeContext:
def get_loss(self):
return self.aux_loss
MOE_CONTEXT = MoeContext()

View File

@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
from colossalai.context.singleton_meta import SingletonMeta
class ParallelContext:
class ParallelContext(metaclass=SingletonMeta):
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
"""
__instance = None
@staticmethod
def get_instance():
if ParallelContext.__instance is None:
ParallelContext()
return ParallelContext.__instance
def __init__(self):
# create a singleton instance
if ParallelContext.__instance is not None:
raise Exception(
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
else:
ParallelContext.__instance = self
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
@@ -510,3 +496,6 @@ class ParallelContext:
def set_virtual_pipeline_parallel_rank(self, rank):
self.virtual_pipeline_parallel_rank = rank
global_context = ParallelContext()

View File

@@ -0,0 +1,18 @@
class SingletonMeta(type):
"""
The Singleton class can be implemented in different ways in Python. Some
possible methods include: base class, decorator, metaclass. We will use the
metaclass because it is best suited for this purpose.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
"""
Possible changes to the value of the `__init__` argument do not affect
the returned instance.
"""
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]