Mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] fix torch 2.0 compatibility (#4936)
* [hotfix] fix launch
* [test] fix test gemini optim
* [shardformer] fix vit
@@ -54,7 +54,7 @@ class ParallelContext(metaclass=SingletonMeta):
         # logging
         self._verbose = False
-        self._logger = get_dist_logger()
+        self._logger = None
 
     @property
     def config(self):
@@ -68,6 +68,12 @@ class ParallelContext(metaclass=SingletonMeta):
     def verbose(self, verbose_: bool):
         self._verbose = verbose_
 
+    @property
+    def logger(self):
+        if self._logger is None:
+            self._logger = get_dist_logger()
+        return self._logger
+
     def load_config(self, config: Union[dict, str]):
         """Loads the configuration from either a dict or a file.
 
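The new logger property makes logger creation lazy: ParallelContext no longer builds a logger in __init__, only on first access. A minimal sketch of the same lazy-initialization pattern in isolation (illustrative names; stdlib logging stands in for get_dist_logger so the snippet is self-contained):

import logging

class Context:
    """Illustrative stand-in for ParallelContext."""

    def __init__(self):
        self._logger = None  # deferred: nothing is constructed at init time

    @property
    def logger(self):
        if self._logger is None:
            # First access builds the logger (get_dist_logger() in the real code).
            self._logger = logging.getLogger("parallel_context")
        return self._logger

ctx = Context()           # cheap: no logger exists yet
ctx.logger.info("ready")  # the logger is created here, on first use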
@@ -527,7 +533,7 @@ class ParallelContext(metaclass=SingletonMeta):
 
         torch.cuda.set_device(device_ordinal)
         if self._verbose:
-            self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}")
+            self.logger.info(f"process rank {global_rank} is bound to device {device_ordinal}")
 
     def set_seed(self, seed: int):
         """Sets seeds for all random libraries.
@@ -563,19 +569,19 @@ class ParallelContext(metaclass=SingletonMeta):
             seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
 
             if self._verbose:
-                self._logger.info(
+                self.logger.info(
                     f"initialized seed on rank {global_rank}, "
                     f"numpy: {seed}, python random: {seed}, {seed_str},"
                     f"the default parallel seed is {ParallelMode.DATA}."
                 )
         else:
             if self._verbose:
-                self._logger.info(
+                self.logger.info(
                     f"initialized seed on rank {global_rank}, "
                     f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
                     ranks=[0],
                 )
-                self._logger.info(
+                self.logger.info(
                     "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states",
                     ranks=[0],
                 )
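Judging from the log messages in this hunk, set_seed seeds python random, numpy, and pytorch, and tracks CUDA RNG states only when CUDA is available. A rough sketch of that shape (an assumption read off the messages, not the actual ColossalAI implementation):

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    random.seed(seed)        # python random
    np.random.seed(seed)     # numpy
    torch.manual_seed(seed)  # pytorch CPU RNG
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # CUDA RNG on every device
    # else: as the warning above says, CUDA RNG states cannot be tracked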
@@ -31,7 +31,7 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         return self.dict[processgroup_key]
 
 
-PYTORCHPGDICT_ = PyTorchProcessGroupDict()
+PYTORCHPGDICT_ = None
 
 
 class ProcessGroup:
@@ -59,6 +59,9 @@ class ProcessGroup:
         if not torch.distributed.is_initialized():
             self.is_init = False
             return
+        global PYTORCHPGDICT_
+        if PYTORCHPGDICT_ is None:
+            PYTORCHPGDICT_ = PyTorchProcessGroupDict()
 
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
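Both files get the same treatment: work that used to happen at import time (building the logger, building the PYTORCHPGDICT_ singleton) is deferred until first use, which is presumably what the "[hotfix] fix launch" item in the commit message refers to. A minimal sketch of the deferred module-level singleton (illustrative names; a plain dict stands in for PyTorchProcessGroupDict):

_PG_DICT = None  # before the fix this was built at import time

def _get_pg_dict():
    # Hypothetical accessor: create the singleton on first use so that
    # importing the module has no side effects.
    global _PG_DICT
    if _PG_DICT is None:
        _PG_DICT = {}  # stands in for PyTorchProcessGroupDict()
    return _PG_DICT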