From 7d81b5b46ebb1e95d32545122bd86b44dcf54db1 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 29 Mar 2022 10:37:11 +0800 Subject: [PATCH] [logging] polish logger format (#543) --- colossalai/logging/logger.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index 5bc2694c6..65dd6ae32 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -5,15 +5,18 @@ import colossalai import logging from pathlib import Path from typing import Union +import inspect from colossalai.context.parallel_mode import ParallelMode try: from rich.logging import RichHandler - _FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s' - logging.basicConfig(level=logging.INFO, format=_FORMAT, handlers=[RichHandler()]) + _FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s' + logging.basicConfig(level=logging.INFO, + format=_FORMAT, + handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)]) except ImportError: - _FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s' + _FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=_FORMAT) @@ -50,6 +53,19 @@ class DistributedLogger: self._logger = logging.getLogger(name) DistributedLogger.__instances[name] = self + @staticmethod + def __get_call_info(): + stack = inspect.stack() + + # stack[1] gives previous function ('info' in our case) + # stack[2] gives before previous function and so on + + fn = stack[2][1] + ln = stack[2][2] + func = stack[2][3] + + return fn, ln, func + @staticmethod def _check_valid_logging_level(level: str): assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' @@ -122,6 +138,8 @@ class DistributedLogger: :param ranks: List of parallel ranks :type ranks: list """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('info', message_prefix, parallel_mode, ranks) self._log('info', message, parallel_mode, ranks) def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): @@ -134,6 +152,8 @@ class DistributedLogger: :param ranks: List of parallel ranks :type ranks: list """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('warning', message_prefix, parallel_mode, ranks) self._log('warning', message, parallel_mode, ranks) def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): @@ -146,6 +166,8 @@ class DistributedLogger: :param ranks: List of parallel ranks :type ranks: list """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('debug', message_prefix, parallel_mode, ranks) self._log('debug', message, parallel_mode, ranks) def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): @@ -158,4 +180,6 @@ class DistributedLogger: :param ranks: List of parallel ranks :type ranks: list """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('error', message_prefix, parallel_mode, ranks) self._log('error', message, parallel_mode, ranks)