mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[doc] improved docstring in the logging module (#861)
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import colossalai
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
import inspect
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
@@ -40,6 +40,7 @@ class DistributedLogger:
|
||||
|
||||
Args:
|
||||
name (str): The name of the logger.
|
||||
|
||||
Returns:
|
||||
DistributedLogger: A DistributedLogger object
|
||||
"""
|
||||
@@ -75,7 +76,7 @@ class DistributedLogger:
|
||||
def _check_valid_logging_level(level: str):
|
||||
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
|
||||
|
||||
def set_level(self, level: str):
|
||||
def set_level(self, level: str) -> None:
|
||||
"""Set the logging level
|
||||
|
||||
Args:
|
||||
@@ -84,7 +85,7 @@ class DistributedLogger:
|
||||
self._check_valid_logging_level(level)
|
||||
self._logger.setLevel(getattr(logging, level))
|
||||
|
||||
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
|
||||
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None:
|
||||
"""Save the logs to file
|
||||
|
||||
Args:
|
||||
@@ -122,7 +123,11 @@ class DistributedLogger:
|
||||
file_handler.setFormatter(formatter)
|
||||
self._logger.addHandler(file_handler)
|
||||
|
||||
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def _log(self,
|
||||
level,
|
||||
message: str,
|
||||
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
|
||||
ranks: List[int] = None) -> None:
|
||||
if ranks is None:
|
||||
getattr(self._logger, level)(message)
|
||||
else:
|
||||
@@ -130,53 +135,53 @@ class DistributedLogger:
|
||||
if local_rank in ranks:
|
||||
getattr(self._logger, level)(message)
|
||||
|
||||
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log an info message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('info', message_prefix, parallel_mode, ranks)
|
||||
self._log('info', message, parallel_mode, ranks)
|
||||
|
||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log a warning message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('warning', message_prefix, parallel_mode, ranks)
|
||||
self._log('warning', message, parallel_mode, ranks)
|
||||
|
||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log a debug message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('debug', message_prefix, parallel_mode, ranks)
|
||||
self._log('debug', message, parallel_mode, ranks)
|
||||
|
||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log an error message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('error', message_prefix, parallel_mode, ranks)
|
||||
|
Reference in New Issue
Block a user