[legacy] move communication and nn to legacy and refactor logger (#4671)

* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy code (#4654)

* [legacy] make logger independent of gpc

* [legacy] make optim independent of registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useless rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybrid parallel example

* [devops] test doc check

* [devops] test doc check
Author: Hongxin Liu
Date: 2023-09-11 16:24:28 +08:00 (committed by GitHub)
Parent: 536397cc95
Commit: 554aa9592e
170 changed files with 781 additions and 758 deletions
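The heart of the logger diff below: DistributedLogger stops resolving ranks through the global parallel context (gpc) and queries torch.distributed directly. A minimal sketch of the swap, with the before/after lines taken from the hunks that follow:

import torch.distributed as dist

# before: rank = colossalai.core.global_context.get_global_rank()
# after: plain torch.distributed, guarded so single-process runs default to rank 0
rank = dist.get_rank() if dist.is_initialized() else 0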

colossalai/logging/logger.py

@@ -6,8 +6,7 @@ import logging
 from pathlib import Path
 from typing import List, Union

-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
+import torch.distributed as dist


 class DistributedLogger:
@@ -63,6 +62,7 @@ class DistributedLogger:
         self._logger.propagate = False

         DistributedLogger.__instances[name] = self
+        self.rank = dist.get_rank() if dist.is_initialized() else 0

     @staticmethod
     def __get_call_info():
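With this hunk the rank is captured once at construction time instead of being looked up on every call. A self-contained sketch of that behavior, using a toy class rather than the real logger: a logger built before dist.init_process_group keeps rank 0.

import torch.distributed as dist

class TinyLogger:
    """Toy stand-in showing construction-time rank capture."""

    def __init__(self):
        # resolved once; initializing the process group later won't update it
        self.rank = dist.get_rank() if dist.is_initialized() else 0

print(TinyLogger().rank)  # 0 in a single-process run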
@@ -109,16 +109,10 @@ class DistributedLogger:
         # create log directory
         path.mkdir(parents=True, exist_ok=True)

         # set the default file name if path is a directory
-        if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
-            rank = 0
-        else:
-            rank = colossalai.core.global_context.get_global_rank()
-
         if suffix is not None:
-            log_file_name = f'rank_{rank}_{suffix}.log'
+            log_file_name = f'rank_{self.rank}_{suffix}.log'
         else:
-            log_file_name = f'rank_{rank}.log'
+            log_file_name = f'rank_{self.rank}.log'
         path = path.joinpath(log_file_name)

         # add file handler
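For reference, the naming scheme above yields one log file per rank; a runnable sketch with assumed values:

from pathlib import Path

rank, suffix = 0, 'train'  # assumed for illustration
log_file_name = f'rank_{rank}_{suffix}.log' if suffix is not None else f'rank_{rank}.log'
print(Path('./logs').joinpath(log_file_name))  # logs/rank_0_train.log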
@@ -128,19 +122,14 @@ class DistributedLogger:
         file_handler.setFormatter(formatter)
         self._logger.addHandler(file_handler)

-    def _log(self,
-             level,
-             message: str,
-             parallel_mode: ParallelMode = ParallelMode.GLOBAL,
-             ranks: List[int] = None) -> None:
+    def _log(self, level, message: str, ranks: List[int] = None) -> None:
         if ranks is None:
             getattr(self._logger, level)(message)
         else:
-            local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
-            if local_rank in ranks:
+            if self.rank in ranks:
                 getattr(self._logger, level)(message)

-    def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def info(self, message: str, ranks: List[int] = None) -> None:
         """Log an info message.

         Args:
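The getattr dispatch in _log is unchanged by this refactor: the level name selects the matching logging.Logger method. A standalone sketch of the equivalence:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('demo')
getattr(logger, 'info')('via getattr')  # identical to the direct call below
logger.info('direct call')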
@@ -150,10 +139,10 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('info', message_prefix, parallel_mode, ranks)
-        self._log('info', message, parallel_mode, ranks)
+        self._log('info', message_prefix, ranks)
+        self._log('info', message, ranks)

-    def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def warning(self, message: str, ranks: List[int] = None) -> None:
         """Log a warning message.

         Args:
@@ -163,10 +152,10 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('warning', message_prefix, parallel_mode, ranks)
-        self._log('warning', message, parallel_mode, ranks)
+        self._log('warning', message_prefix, ranks)
+        self._log('warning', message, ranks)

-    def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def debug(self, message: str, ranks: List[int] = None) -> None:
         """Log a debug message.

         Args:
@@ -176,10 +165,10 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('debug', message_prefix, parallel_mode, ranks)
-        self._log('debug', message, parallel_mode, ranks)
+        self._log('debug', message_prefix, ranks)
+        self._log('debug', message, ranks)

-    def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+    def error(self, message: str, ranks: List[int] = None) -> None:
         """Log an error message.

         Args:
@@ -189,5 +178,5 @@ class DistributedLogger:
             ranks (List[int]): List of parallel ranks.
         """
         message_prefix = "{}:{} {}".format(*self.__get_call_info())
-        self._log('error', message_prefix, parallel_mode, ranks)
-        self._log('error', message, parallel_mode, ranks)
+        self._log('error', message_prefix, ranks)
+        self._log('error', message, ranks)
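After the refactor, call sites simply drop the parallel_mode argument, and ranks means global ranks. A minimal usage sketch, assuming get_dist_logger (the factory exposed by colossalai.logging) as the entry point:

from colossalai.logging import get_dist_logger

logger = get_dist_logger()
logger.info('loss = 0.42', ranks=[0])  # emitted only on global rank 0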