fix(core): Fix fschat and alembic log conflict (#919)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng
2023-12-12 09:14:40 +08:00
committed by GitHub
parent cbba50ab1b
commit f1f4f4adda
5 changed files with 64 additions and 101 deletions

View File

@@ -1,8 +1,5 @@
from .utils import (
get_gpu_memory,
StreamToLogger,
disable_torch_init,
pretty_print_semaphore,
server_error_msg,
get_or_create_event_loop,
)

View File

@@ -111,58 +111,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ""
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ""
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == "\n":
encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
self.logger.log(self.log_level, encoded_message.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != "":
encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
self.logger.log(self.log_level, encoded_message.rstrip())
self.linebuf = ""
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
loop = None
try: