diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py index 937355ef8..97fe4f89d 100644 --- a/colossalai/logging/__init__.py +++ b/colossalai/logging/__init__.py @@ -6,22 +6,20 @@ from .logger import DistributedLogger __all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] -def get_dist_logger(name='colossalai'): +def get_dist_logger(name: str = 'colossalai') -> DistributedLogger: """Get logger instance based on name. The DistributedLogger will create singleton instances, which means that only one logger instance is created per name. Args: - - :param name: name of the logger, name must be unique - :type name: str - - :return: a distributed logger instance - :rtype: :class:`colossalai.logging.DistributedLogger` + name (str): name of the logger, name must be unique + + Returns: + :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance. """ return DistributedLogger.get_instance(name=name) -def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']): +def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None: """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". Args: diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index ddd7ec5c9..acfc73c2d 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -4,7 +4,7 @@ import colossalai import logging from pathlib import Path -from typing import Union +from typing import Union, List import inspect from colossalai.context.parallel_mode import ParallelMode @@ -40,6 +40,7 @@ class DistributedLogger: Args: name (str): The name of the logger. + Returns: DistributedLogger: A DistributedLogger object """ @@ -75,7 +76,7 @@ class DistributedLogger: def _check_valid_logging_level(level: str): assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' - def set_level(self, level: str): + def set_level(self, level: str) -> None: """Set the logging level Args: @@ -84,7 +85,7 @@ class DistributedLogger: self._check_valid_logging_level(level) self._logger.setLevel(getattr(logging, level)) - def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None): + def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None: """Save the logs to file Args: @@ -122,7 +123,11 @@ class DistributedLogger: file_handler.setFormatter(formatter) self._logger.addHandler(file_handler) - def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): + def _log(self, + level, + message: str, + parallel_mode: ParallelMode = ParallelMode.GLOBAL, + ranks: List[int] = None) -> None: if ranks is None: getattr(self._logger, level)(message) else: @@ -130,53 +135,53 @@ class DistributedLogger: if local_rank in ranks: getattr(self._logger, level)(message) - def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): + def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: """Log an info message. Args: message (str): The message to be logged. parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. - ranks (List): List of parallel ranks. + ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) self._log('info', message_prefix, parallel_mode, ranks) self._log('info', message, parallel_mode, ranks) - def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): + def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: """Log a warning message. Args: message (str): The message to be logged. parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. - ranks (List): List of parallel ranks. + ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) self._log('warning', message_prefix, parallel_mode, ranks) self._log('warning', message, parallel_mode, ranks) - def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): + def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: """Log a debug message. Args: message (str): The message to be logged. parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. - ranks (List): List of parallel ranks. + ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) self._log('debug', message_prefix, parallel_mode, ranks) self._log('debug', message, parallel_mode, ranks) - def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): + def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: """Log an error message. Args: message (str): The message to be logged. parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. - ranks (List): List of parallel ranks. + ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) self._log('error', message_prefix, parallel_mode, ranks)