[logging] polish logger format (#543)

This commit is contained in:
Jiarui Fang
2022-03-29 10:37:11 +08:00
committed by GitHub
parent 1f90a3b129
commit 7d81b5b46e

View File

@@ -5,15 +5,18 @@ import colossalai
import logging
from pathlib import Path
from typing import Union
import inspect
from colossalai.context.parallel_mode import ParallelMode
try:
from rich.logging import RichHandler
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT, handlers=[RichHandler()])
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO,
format=_FORMAT,
handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)])
except ImportError:
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
@@ -50,6 +53,19 @@ class DistributedLogger:
self._logger = logging.getLogger(name)
DistributedLogger.__instances[name] = self
@staticmethod
def __get_call_info():
stack = inspect.stack()
# stack[1] gives previous function ('info' in our case)
# stack[2] gives before previous function and so on
fn = stack[2][1]
ln = stack[2][2]
func = stack[2][3]
return fn, ln, func
@staticmethod
def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
@@ -122,6 +138,8 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('info', message_prefix, parallel_mode, ranks)
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
@@ -134,6 +152,8 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('warning', message_prefix, parallel_mode, ranks)
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
@@ -146,6 +166,8 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('debug', message_prefix, parallel_mode, ranks)
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
@@ -158,4 +180,6 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('error', message_prefix, parallel_mode, ranks)
self._log('error', message, parallel_mode, ranks)