Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 17:46:42 +00:00)
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)
* [legacy] refactor logger and clean up legacy codes (#4654)
* [legacy] make logger independent of gpc
* [legacy] make optim independent of registry
* [legacy] move test engine to legacy
* [legacy] move nn to legacy (#4656)
* [legacy] move nn to legacy
* [checkpointio] fix save hf config
* [test] remove useless rpc pp test
* [legacy] fix nn init
* [example] skip tutorial hybrid parallel example
* [devops] test doc check
* [devops] test doc check
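The heart of the logger refactor is visible in the diff below: DistributedLogger no longer asks the global parallel context (gpc) for a rank, but caches the torch.distributed rank once at construction and reuses it for log-file naming and rank filtering. A minimal standalone sketch of that pattern, independent of ColossalAI (the class name here is illustrative, not part of the codebase):

import logging

import torch.distributed as dist


class RankFilteredLogger:
    """Illustrative sketch of the refactored pattern: cache the rank once."""

    def __init__(self, name: str):
        self._logger = logging.getLogger(name)
        # Fall back to rank 0 so the logger also works when
        # torch.distributed has not been initialized (single-process runs).
        self.rank = dist.get_rank() if dist.is_initialized() else 0

    def info(self, message: str, ranks=None) -> None:
        # Emit unconditionally, or only on the listed global ranks.
        if ranks is None or self.rank in ranks:
            self._logger.info(message)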
@@ -6,8 +6,7 @@ import logging
 from pathlib import Path
 from typing import List, Union
 
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
+import torch.distributed as dist
 
 
 class DistributedLogger:
@@ -63,6 +62,7 @@ class DistributedLogger:
         self._logger.propagate = False
 
         DistributedLogger.__instances[name] = self
+        self.rank = dist.get_rank() if dist.is_initialized() else 0
 
     @staticmethod
     def __get_call_info():
@@ -109,16 +109,10 @@ class DistributedLogger:
         # create log directory
         path.mkdir(parents=True, exist_ok=True)
 
         # set the default file name if path is a directory
-        if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
-            rank = 0
-        else:
-            rank = colossalai.core.global_context.get_global_rank()
-
         if suffix is not None:
-            log_file_name = f'rank_{rank}_{suffix}.log'
+            log_file_name = f'rank_{self.rank}_{suffix}.log'
         else:
-            log_file_name = f'rank_{rank}.log'
+            log_file_name = f'rank_{self.rank}.log'
         path = path.joinpath(log_file_name)
 
         # add file handler
@@ -128,19 +122,14 @@ class DistributedLogger:
         file_handler.setFormatter(formatter)
         self._logger.addHandler(file_handler)
 
-    def _log(self,
-             level,
-             message: str,
-             parallel_mode: ParallelMode = ParallelMode.GLOBAL,
-             ranks: List[int] = None) -> None:
+    def _log(self, level, message: str, ranks: List[int] = None) -> None:
         if ranks is None:
             getattr(self._logger, level)(message)
         else:
-            local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
-            if local_rank in ranks:
+            if self.rank in ranks:
                 getattr(self._logger, level)(message)
 
-    def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def info(self, message: str, ranks: List[int] = None) -> None:
         """Log an info message.
 
         Args:
@@ -150,10 +139,10 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('info', message_prefix, parallel_mode, ranks)
-        self._log('info', message, parallel_mode, ranks)
+        self._log('info', message_prefix, ranks)
+        self._log('info', message, ranks)
 
-    def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def warning(self, message: str, ranks: List[int] = None) -> None:
         """Log a warning message.
 
         Args:
@@ -163,10 +152,10 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('warning', message_prefix, parallel_mode, ranks)
-        self._log('warning', message, parallel_mode, ranks)
+        self._log('warning', message_prefix, ranks)
+        self._log('warning', message, ranks)
 
-    def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def debug(self, message: str, ranks: List[int] = None) -> None:
         """Log a debug message.
 
         Args:
@@ -176,10 +165,10 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('debug', message_prefix, parallel_mode, ranks)
-        self._log('debug', message, parallel_mode, ranks)
+        self._log('debug', message_prefix, ranks)
+        self._log('debug', message, ranks)
 
-    def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def error(self, message: str, ranks: List[int] = None) -> None:
         """Log an error message.
 
         Args:
@@ -189,5 +178,5 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('error', message_prefix, parallel_mode, ranks)
-        self._log('error', message, parallel_mode, ranks)
+        self._log('error', message_prefix, ranks)
+        self._log('error', message, ranks)
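After this change the logging methods no longer take a parallel_mode argument; rank filtering compares against the cached global rank. Assuming the logger is still obtained through colossalai.logging.get_dist_logger (the entry point used in earlier releases), a call site would look like:

from colossalai.logging import get_dist_logger

logger = get_dist_logger()
# Only global rank 0 emits this line; all other ranks skip it in _log.
logger.info("training started", ranks=[0])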