diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index acfc73c2d..8d50ee418 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -1,24 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import colossalai +import inspect import logging from pathlib import Path -from typing import Union, List -import inspect +from typing import List, Union +import colossalai from colossalai.context.parallel_mode import ParallelMode -try: - from rich.logging import RichHandler - _FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s' - logging.basicConfig(level=logging.INFO, - format=_FORMAT, - handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)]) -except ImportError: - _FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s' - logging.basicConfig(level=logging.INFO, format=_FORMAT) - class DistributedLogger: """This is a distributed event logger class essentially based on :class:`logging`. @@ -40,7 +30,7 @@ class DistributedLogger: Args: name (str): The name of the logger. - + Returns: DistributedLogger: A DistributedLogger object """ @@ -55,8 +45,23 @@ class DistributedLogger: raise Exception( 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') else: + handler = None + formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + try: + from rich.logging import RichHandler + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) + handler.setFormatter(formatter) + except ImportError: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + self._name = name self._logger = logging.getLogger(name) + self._logger.setLevel(logging.INFO) + if handler is not None: + self._logger.addHandler(handler) + self._logger.propagate = False + DistributedLogger.__instances[name] = self @staticmethod