[doc] improved docstring in the logging module (#861)

This commit is contained in:
Frank Lee 2022-04-25 13:42:00 +08:00 committed by GitHub
parent 8004c8e938
commit b862d89d00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 20 deletions

View File

@ -6,22 +6,20 @@ from .logger import DistributedLogger
__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] __all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers']
def get_dist_logger(name='colossalai'): def get_dist_logger(name: str = 'colossalai') -> DistributedLogger:
"""Get logger instance based on name. The DistributedLogger will create singleton instances, """Get logger instance based on name. The DistributedLogger will create singleton instances,
which means that only one logger instance is created per name. which means that only one logger instance is created per name.
Args: Args:
name (str): name of the logger, name must be unique
:param name: name of the logger, name must be unique
:type name: str Returns:
:class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance.
:return: a distributed logger instance
:rtype: :class:`colossalai.logging.DistributedLogger`
""" """
return DistributedLogger.get_instance(name=name) return DistributedLogger.get_instance(name=name)
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']): def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None:
"""Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai".
Args: Args:

View File

@ -4,7 +4,7 @@
import colossalai import colossalai
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union, List
import inspect import inspect
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
@ -40,6 +40,7 @@ class DistributedLogger:
Args: Args:
name (str): The name of the logger. name (str): The name of the logger.
Returns: Returns:
DistributedLogger: A DistributedLogger object DistributedLogger: A DistributedLogger object
""" """
@ -75,7 +76,7 @@ class DistributedLogger:
def _check_valid_logging_level(level: str): def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
def set_level(self, level: str): def set_level(self, level: str) -> None:
"""Set the logging level """Set the logging level
Args: Args:
@ -84,7 +85,7 @@ class DistributedLogger:
self._check_valid_logging_level(level) self._check_valid_logging_level(level)
self._logger.setLevel(getattr(logging, level)) self._logger.setLevel(getattr(logging, level))
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None): def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None:
"""Save the logs to file """Save the logs to file
Args: Args:
@ -122,7 +123,11 @@ class DistributedLogger:
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler) self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def _log(self,
level,
message: str,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks: List[int] = None) -> None:
if ranks is None: if ranks is None:
getattr(self._logger, level)(message) getattr(self._logger, level)(message)
else: else:
@ -130,53 +135,53 @@ class DistributedLogger:
if local_rank in ranks: if local_rank in ranks:
getattr(self._logger, level)(message) getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an info message. """Log an info message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('info', message_prefix, parallel_mode, ranks) self._log('info', message_prefix, parallel_mode, ranks)
self._log('info', message, parallel_mode, ranks) self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a warning message. """Log a warning message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('warning', message_prefix, parallel_mode, ranks) self._log('warning', message_prefix, parallel_mode, ranks)
self._log('warning', message, parallel_mode, ranks) self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a debug message. """Log a debug message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('debug', message_prefix, parallel_mode, ranks) self._log('debug', message_prefix, parallel_mode, ranks)
self._log('debug', message, parallel_mode, ranks) self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an error message. """Log an error message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('error', message_prefix, parallel_mode, ranks) self._log('error', message_prefix, parallel_mode, ranks)