From 7ef91606e17cc1e991496c6cc74f73cbd42313ae Mon Sep 17 00:00:00 2001 From: Season Date: Thu, 25 Apr 2024 14:45:52 +0800 Subject: [PATCH] [Fix]: implement thread-safe singleton to avoid deadlock for very large-scale training scenarios (#5625) * implement thread-safe singleton * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor singleton implementation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/context/singleton_meta.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py index 3088b0dff..86a8aa5d8 100644 --- a/colossalai/context/singleton_meta.py +++ b/colossalai/context/singleton_meta.py @@ -1,22 +1,27 @@ +import threading + + +class SingletonMeta(type): """ - The Singleton class can be implemented in different ways in Python. Some - possible methods include: base class, decorator, metaclass. We will use the - metaclass because it is best suited for this purpose. + Thread-safe Singleton Meta with double-checked locking. + Reference: https://en.wikipedia.org/wiki/Double-checked_locking """ _instances = {} + _lock = threading.Lock() def __call__(cls, *args, **kwargs): - """ - Possible changes to the value of the `__init__` argument do not affect - the returned instance.
- """ + # First check (without locking) for performance reasons if cls not in cls._instances: - instance = super().__call__(*args, **kwargs) - cls._instances[cls] = instance + # Acquire a lock before proceeding to the second check + with cls._lock: + # Second check with lock held to ensure thread safety + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance else: assert ( len(args) == 0 and len(kwargs) == 0 - ), f"{cls.__name__} is a singleton class and a instance has been created." + ), f"{cls.__name__} is a singleton class and an instance has been created." + return cls._instances[cls]