[hotfix] fix torch 2.0 compatibility (#4936)

* [hotfix] fix launch

* [test] fix test gemini optim

* [shardformer] fix vit
Author: Hongxin Liu (committed by GitHub)
Date: 2023-10-18 11:05:25 +08:00
parent 21ba89cab6
commit 1f5d2e8062
6 changed files with 39 additions and 55 deletions

@@ -54,7 +54,7 @@ class ParallelContext(metaclass=SingletonMeta):
         # logging
         self._verbose = False
-        self._logger = get_dist_logger()
+        self._logger = None

     @property
     def config(self):
@@ -68,6 +68,12 @@ class ParallelContext(metaclass=SingletonMeta):
     def verbose(self, verbose_: bool):
         self._verbose = verbose_

+    @property
+    def logger(self):
+        if self._logger is None:
+            self._logger = get_dist_logger()
+        return self._logger
+
     def load_config(self, config: Union[dict, str]):
         """Loads the configuration from either a dict or a file.
@@ -527,7 +533,7 @@ class ParallelContext(metaclass=SingletonMeta):
         torch.cuda.set_device(device_ordinal)
         if self._verbose:
-            self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}")
+            self.logger.info(f"process rank {global_rank} is bound to device {device_ordinal}")

     def set_seed(self, seed: int):
         """Sets seeds for all random libraries.
@@ -563,19 +569,19 @@ class ParallelContext(metaclass=SingletonMeta):
             seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
             if self._verbose:
-                self._logger.info(
+                self.logger.info(
                     f"initialized seed on rank {global_rank}, "
                     f"numpy: {seed}, python random: {seed}, {seed_str},"
                     f"the default parallel seed is {ParallelMode.DATA}."
                 )
         else:
             if self._verbose:
-                self._logger.info(
+                self.logger.info(
                     f"initialized seed on rank {global_rank}, "
                     f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
                     ranks=[0],
                 )
-                self._logger.info(
+                self.logger.info(
                     "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states",
                     ranks=[0],
                 )
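Note: the changes in this file replace eager logger construction with lazy construction, so instantiating ParallelContext no longer calls get_dist_logger() before distributed state is set up (the torch 2.0 launch failure mode). A minimal, self-contained sketch of the same lazy-property pattern; Context and the stub get_dist_logger below are illustrative stand-ins, not the actual ColossalAI API:

    import logging

    def get_dist_logger():
        # Stand-in for ColossalAI's distributed logger factory; in the
        # real code this is only safe to call after distributed init.
        return logging.getLogger("dist")

    class Context:
        def __init__(self) -> None:
            # Deferred: building the logger here would run at construction
            # time, before torch.distributed is initialized.
            self._logger = None

        @property
        def logger(self):
            # First access builds and caches the logger.
            if self._logger is None:
                self._logger = get_dist_logger()
            return self._logger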

@@ -31,7 +31,7 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         return self.dict[processgroup_key]

-PYTORCHPGDICT_ = PyTorchProcessGroupDict()
+PYTORCHPGDICT_ = None


class ProcessGroup:
@@ -59,6 +59,9 @@ class ProcessGroup:
         if not torch.distributed.is_initialized():
             self.is_init = False
             return
+        global PYTORCHPGDICT_
+        if PYTORCHPGDICT_ is None:
+            PYTORCHPGDICT_ = PyTorchProcessGroupDict()

         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
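Note: the same deferral is applied here at module scope: PYTORCHPGDICT_ is no longer built at import time, but on the first ProcessGroup construction after torch.distributed is initialized. A minimal sketch of this lazy module-level singleton pattern; Registry and get_registry are illustrative names, not the library's API:

    _REGISTRY = None

    class Registry:
        """Holds process-group-like entries; created once, on demand."""

        def __init__(self) -> None:
            self.dict = {}

    def get_registry() -> Registry:
        # Create the singleton on first use instead of at import time, so
        # merely importing the module never touches distributed state.
        global _REGISTRY
        if _REGISTRY is None:
            _REGISTRY = Registry()
        return _REGISTRY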