[polish] polish singleton and global context (#500)
colossalai/context/__init__.py
@@ -1,6 +1,6 @@
 from .config import Config, ConfigException
 from .parallel_context import ParallelContext
-from .moe_context import MoeContext
 from .parallel_mode import ParallelMode
+from .moe_context import MOE_CONTEXT
 from .process_group_initializer import *
 from .random import *
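After this hunk the package re-exports the ready-made MOE_CONTEXT instance rather than the MoeContext class. A minimal usage sketch, assuming the package layout shown in this diff:

# Sketch only: callers now import the singleton instance, not the class.
from colossalai.context import MOE_CONTEXT          # new export
# from colossalai.context import MoeContext        # old export, no longer re-exported here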
colossalai/context/moe_context.py
@@ -1,6 +1,9 @@
 import torch
 import torch.distributed as dist
-from .parallel_mode import ParallelMode
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context.singleton_meta import SingletonMeta

 from typing import Tuple

@@ -56,17 +59,10 @@ class MoeParallelInfo:
         self.dp_group = group


-class MoeContext:
+class MoeContext(metaclass=SingletonMeta):
     """MoE parallel context manager. This class manages different
     parallel groups in MoE context and MoE loss in training.
     """
-    __instance = None
-
-    @staticmethod
-    def get_instance():
-        if MoeContext.__instance is None:
-            MoeContext.__instance = MoeContext()
-        return MoeContext.__instance

     def __init__(self):
         self.world_size = 1
@@ -160,3 +156,6 @@ class MoeContext:

     def get_loss(self):
         return self.aux_loss
+
+
+MOE_CONTEXT = MoeContext()
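Net effect of the two hunks above: the hand-rolled get_instance() accessor is gone, and the module-level MOE_CONTEXT plus the SingletonMeta metaclass guarantee one shared object. A small sketch of what this implies, assuming colossalai is importable:

from colossalai.context import MOE_CONTEXT
from colossalai.context.moe_context import MoeContext

# SingletonMeta caches the first construction, so any later call to
# MoeContext() returns the very object bound to MOE_CONTEXT.
assert MoeContext() is MOE_CONTEXT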
colossalai/context/parallel_context.py
@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
+from colossalai.context.singleton_meta import SingletonMeta


-class ParallelContext:
+class ParallelContext(metaclass=SingletonMeta):
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.

     """

-    __instance = None
-
-    @staticmethod
-    def get_instance():
-        if ParallelContext.__instance is None:
-            ParallelContext()
-        return ParallelContext.__instance
-
     def __init__(self):
-        # create a singleton instance
-        if ParallelContext.__instance is not None:
-            raise Exception(
-                'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
-        else:
-            ParallelContext.__instance = self
-
         # distributed settings
         self._global_ranks = dict()
         self._local_ranks = dict()

@@ -510,3 +496,6 @@ class ParallelContext:

     def set_virtual_pipeline_parallel_rank(self, rank):
         self.virtual_pipeline_parallel_rank = rank
+
+
+global_context = ParallelContext()
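Note the behavioral change here: the old __init__ raised an Exception on a second construction, while the metaclass now silently returns the cached instance. A sketch of the new behavior, assuming colossalai is importable:

from colossalai.context.parallel_context import ParallelContext, global_context

# Previously a second ParallelContext() raised; with SingletonMeta it
# simply hands back the one instance created for global_context.
ctx = ParallelContext()
assert ctx is global_context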
colossalai/context/singleton_meta.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+class SingletonMeta(type):
+    """
+    The Singleton class can be implemented in different ways in Python. Some
+    possible methods include: base class, decorator, metaclass. We will use the
+    metaclass because it is best suited for this purpose.
+    """
+
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        """
+        Possible changes to the value of the `__init__` argument do not affect
+        the returned instance.
+        """
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
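A toy demonstration of the metaclass above, using a hypothetical Settings class (not from the repo), which shows the documented behavior that constructor arguments on later calls are ignored:

class Settings(metaclass=SingletonMeta):
    def __init__(self, value=0):
        self.value = value

a = Settings(value=1)
b = Settings(value=2)   # __init__ does not run again; value=2 is ignored
assert a is b and a.value == 1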