[doc] improved docstring in the logging module (#861)

This commit is contained in:
Frank Lee
2022-04-25 13:42:00 +08:00
committed by GitHub
parent 8004c8e938
commit b862d89d00
2 changed files with 23 additions and 20 deletions

View File

@@ -4,7 +4,7 @@
import colossalai
import logging
from pathlib import Path
from typing import Union
from typing import Union, List
import inspect
from colossalai.context.parallel_mode import ParallelMode
@@ -40,6 +40,7 @@ class DistributedLogger:
Args:
name (str): The name of the logger.
Returns:
DistributedLogger: A DistributedLogger object
"""
@@ -75,7 +76,7 @@ class DistributedLogger:
def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
def set_level(self, level: str):
def set_level(self, level: str) -> None:
"""Set the logging level
Args:
@@ -84,7 +85,7 @@ class DistributedLogger:
self._check_valid_logging_level(level)
self._logger.setLevel(getattr(logging, level))
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None:
"""Save the logs to file
Args:
@@ -122,7 +123,11 @@ class DistributedLogger:
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def _log(self,
level,
message: str,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks: List[int] = None) -> None:
if ranks is None:
getattr(self._logger, level)(message)
else:
@@ -130,53 +135,53 @@ class DistributedLogger:
if local_rank in ranks:
getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an info message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('info', message_prefix, parallel_mode, ranks)
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a warning message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('warning', message_prefix, parallel_mode, ranks)
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a debug message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('debug', message_prefix, parallel_mode, ranks)
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an error message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('error', message_prefix, parallel_mode, ranks)