mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,23 +3,23 @@ from typing import List, Optional
|
||||
|
||||
from .logger import DistributedLogger
|
||||
|
||||
__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers']
|
||||
__all__ = ["get_dist_logger", "DistributedLogger", "disable_existing_loggers"]
|
||||
|
||||
|
||||
def get_dist_logger(name: str = 'colossalai') -> DistributedLogger:
|
||||
def get_dist_logger(name: str = "colossalai") -> DistributedLogger:
|
||||
"""Get logger instance based on name. The DistributedLogger will create singleton instances,
|
||||
which means that only one logger instance is created per name.
|
||||
|
||||
Args:
|
||||
name (str): name of the logger, name must be unique
|
||||
|
||||
|
||||
Returns:
|
||||
:class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance.
|
||||
"""
|
||||
return DistributedLogger.get_instance(name=name)
|
||||
|
||||
|
||||
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None:
|
||||
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ["colossalai"]) -> None:
|
||||
"""Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai".
|
||||
|
||||
Args:
|
||||
|
@@ -42,12 +42,14 @@ class DistributedLogger:
|
||||
def __init__(self, name):
|
||||
if name in DistributedLogger.__instances:
|
||||
raise Exception(
|
||||
'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
|
||||
"Logger with the same name has been created, you should use colossalai.logging.get_dist_logger"
|
||||
)
|
||||
else:
|
||||
handler = None
|
||||
formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s')
|
||||
formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s")
|
||||
try:
|
||||
from rich.logging import RichHandler
|
||||
|
||||
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
|
||||
handler.setFormatter(formatter)
|
||||
except ImportError:
|
||||
@@ -79,7 +81,7 @@ class DistributedLogger:
|
||||
|
||||
@staticmethod
|
||||
def _check_valid_logging_level(level: str):
|
||||
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
|
||||
assert level in ["INFO", "DEBUG", "WARNING", "ERROR"], "found invalid logging level"
|
||||
|
||||
def set_level(self, level: str) -> None:
|
||||
"""Set the logging level
|
||||
@@ -90,7 +92,7 @@ class DistributedLogger:
|
||||
self._check_valid_logging_level(level)
|
||||
self._logger.setLevel(getattr(logging, level))
|
||||
|
||||
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None:
|
||||
def log_to_file(self, path: Union[str, Path], mode: str = "a", level: str = "INFO", suffix: str = None) -> None:
|
||||
"""Save the logs to file
|
||||
|
||||
Args:
|
||||
@@ -99,8 +101,7 @@ class DistributedLogger:
|
||||
level (str): Can only be INFO, DEBUG, WARNING and ERROR.
|
||||
suffix (str): The suffix string of log's name.
|
||||
"""
|
||||
assert isinstance(path, (str, Path)), \
|
||||
f'expected argument path to be type str or Path, but got {type(path)}'
|
||||
assert isinstance(path, (str, Path)), f"expected argument path to be type str or Path, but got {type(path)}"
|
||||
self._check_valid_logging_level(level)
|
||||
|
||||
if isinstance(path, str):
|
||||
@@ -110,15 +111,15 @@ class DistributedLogger:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if suffix is not None:
|
||||
log_file_name = f'rank_{self.rank}_{suffix}.log'
|
||||
log_file_name = f"rank_{self.rank}_{suffix}.log"
|
||||
else:
|
||||
log_file_name = f'rank_{self.rank}.log'
|
||||
log_file_name = f"rank_{self.rank}.log"
|
||||
path = path.joinpath(log_file_name)
|
||||
|
||||
# add file handler
|
||||
file_handler = logging.FileHandler(path, mode)
|
||||
file_handler.setLevel(getattr(logging, level))
|
||||
formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s')
|
||||
formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s")
|
||||
file_handler.setFormatter(formatter)
|
||||
self._logger.addHandler(file_handler)
|
||||
|
||||
@@ -137,8 +138,8 @@ class DistributedLogger:
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('info', message_prefix, ranks)
|
||||
self._log('info', message, ranks)
|
||||
self._log("info", message_prefix, ranks)
|
||||
self._log("info", message, ranks)
|
||||
|
||||
def warning(self, message: str, ranks: List[int] = None) -> None:
|
||||
"""Log a warning message.
|
||||
@@ -148,8 +149,8 @@ class DistributedLogger:
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('warning', message_prefix, ranks)
|
||||
self._log('warning', message, ranks)
|
||||
self._log("warning", message_prefix, ranks)
|
||||
self._log("warning", message, ranks)
|
||||
|
||||
def debug(self, message: str, ranks: List[int] = None) -> None:
|
||||
"""Log a debug message.
|
||||
@@ -159,8 +160,8 @@ class DistributedLogger:
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('debug', message_prefix, ranks)
|
||||
self._log('debug', message, ranks)
|
||||
self._log("debug", message_prefix, ranks)
|
||||
self._log("debug", message, ranks)
|
||||
|
||||
def error(self, message: str, ranks: List[int] = None) -> None:
|
||||
"""Log an error message.
|
||||
@@ -170,5 +171,5 @@ class DistributedLogger:
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('error', message_prefix, ranks)
|
||||
self._log('error', message, ranks)
|
||||
self._log("error", message_prefix, ranks)
|
||||
self._log("error", message, ranks)
|
||||
|
Reference in New Issue
Block a user