# DB-GPT/dbgpt/util/utils.py
# Snapshot dated 2024-07-03 00:06:13 +08:00 (207 lines, 6.3 KiB, Python).

import asyncio
import logging
import logging.handlers
import os
import sys
from typing import Any, List, Optional, cast
from dbgpt.configs.model_config import LOGDIR
try:
    from termcolor import colored
except ImportError:
    # termcolor is optional: without it, fall back to a no-op that hands the
    # text back untouched.
    def colored(x, *args, **kwargs):
        """Return ``x`` as-is; any color/attribute arguments are ignored."""
        return x


# Error text addressed to end users (it asks them to regenerate or refresh
# the page).
server_error_msg = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)

# Shared file handler; stays None until ``_build_logger`` is first called
# with a log filename, which creates it exactly once.
handler: Optional[logging.Handler] = None
def _get_logging_level() -> str:
    """Return the level name from ``DBGPT_LOG_LEVEL``, or ``"INFO"`` when unset."""
    return os.environ.get("DBGPT_LOG_LEVEL", "INFO")
def setup_logging_level(
    logging_level: Optional[str] = None, logger_name: Optional[str] = None
):
    """Apply a log level to a named logger, or to the root logger.

    Args:
        logging_level: Level name such as ``"info"`` / ``"DEBUG"`` (matched
            case-insensitively) or a numeric level. When falsy, the value of
            ``DBGPT_LOG_LEVEL`` (default ``"INFO"``) is used instead.
        logger_name: Logger to configure. When falsy, the root logger is
            configured through ``logging.basicConfig``, which does nothing if
            the root logger already has handlers.

    Raises:
        ValueError: if the level name is not a known logging level.
    """
    level = logging_level or _get_logging_level()
    if type(level) is str:
        # getLevelName maps a known name to its number; an unknown name comes
        # back as the string "Level <NAME>", which setLevel/basicConfig reject.
        level = logging.getLevelName(level.upper())

    if not logger_name:
        logging.basicConfig(level=level, encoding="utf-8")
        return
    logging.getLogger(logger_name).setLevel(cast(str, level))
def setup_logging(
    logger_name: str,
    logging_level: Optional[str] = None,
    logger_filename: Optional[str] = None,
    redirect_stdio: bool = False,
):
    """Build the logger ``logger_name`` and colorize it when possible.

    Args:
        logger_name: Name of the logger to configure.
        logging_level: Level name; falls back to ``DBGPT_LOG_LEVEL`` (default
            ``"INFO"``) when falsy.
        logger_filename: Optional log file name, forwarded to ``_build_logger``.
        redirect_stdio: Forwarded to ``_build_logger``.
    """
    level = logging_level or _get_logging_level()
    configured = _build_logger(logger_name, level, logger_filename, redirect_stdio)
    try:
        # coloredlogs is optional; if it cannot be imported the plain
        # formatting set up by _build_logger is kept.
        import coloredlogs

        coloredlogs.install(level=level or "INFO", logger=configured)
    except ImportError:
        pass
def get_gpu_memory(max_gpus=None):
    """Return the available memory, in GiB, of each visible CUDA device.

    Args:
        max_gpus: If given, inspect at most this many devices, starting at
            index 0; otherwise inspect every device ``torch`` reports.

    Returns:
        list[float]: One entry per inspected device, in device-index order.

    Note:
        "Available" here is the device's total memory minus
        ``torch.cuda.memory_allocated()``, i.e. memory occupied by tensors of
        this process only. Memory used by other processes, or reserved but
        unused by the caching allocator, is not subtracted.
    """
    # Imported here rather than at module level, so importing this module
    # does not require torch.
    import torch

    bytes_per_gib = 1024**3
    device_count = torch.cuda.device_count()
    if max_gpus is not None:
        device_count = min(max_gpus, device_count)

    free_per_device = []
    for idx in range(device_count):
        with torch.cuda.device(idx):
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            total_gib = props.total_memory / bytes_per_gib
            used_gib = torch.cuda.memory_allocated() / bytes_per_gib
            free_per_device.append(total_gib - used_gib)
    return free_per_device
def _build_logger(
    logger_name,
    logging_level: Optional[str] = None,
    logger_filename: Optional[str] = None,
    redirect_stdio: bool = False,
):
    """Configure root/file/stdio handlers and return the logger ``logger_name``.

    Steps:
      1. If the root logger has no handlers yet, create one through
         ``setup_logging_level`` and give it the common formatter.
      2. On the first call that passes ``logger_filename`` (while the module
         global ``handler`` is still None), create a ``TimedRotatingFileHandler``
         under ``LOGDIR`` that rotates every day (``when="D"``). Store it in
         ``handler``, and attach it to the root logger and to every logger
         registered so far, forcing ``propagate = True`` on those loggers.
         ``backupCount`` is left at 0, so rotated files are never deleted.
      3. If ``redirect_stdio`` is true, attach stdout and stderr stream
         handlers to the root logger.
      4. Apply ``logging_level`` to ``logger_name``.

    Args:
        logger_name: Name of the logger to configure and return.
        logging_level: Level name (case-insensitive) or number. When None, the
            root level falls back to ``DBGPT_LOG_LEVEL`` and the file handler
            keeps its default level.
        logger_filename: File name, relative to ``LOGDIR``, for the shared
            rotating log file. Only honoured by the first call that supplies it.
        redirect_stdio: Whether to add stdout/stderr handlers to the root logger.

    Returns:
        logging.Logger: the logger named ``logger_name``.

    Raises:
        ValueError: if ``logging_level`` is not a known level name.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    root_logger = logging.getLogger()

    # Set the format of root handlers
    if not root_logger.handlers:
        setup_logging_level(logging_level=logging_level)
        root_logger.handlers[0].setFormatter(formatter)

    # Add a file handler for all loggers
    if handler is None and logger_filename:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        # Write UTF-8 explicitly, as basicConfig does in setup_logging_level,
        # so non-ASCII messages do not depend on the platform's locale encoding.
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True, encoding="utf-8"
        )
        handler.setFormatter(formatter)
        # Ensure the handler level is set correctly
        if logging_level is not None:
            # Handler.setLevel only accepts upper-case names ("INFO").
            # setup_logging_level upper-cases before resolving, so do the same
            # here; otherwise e.g. DBGPT_LOG_LEVEL=info raises ValueError.
            file_level = (
                logging_level.upper()
                if isinstance(logging_level, str)
                else logging_level
            )
            handler.setLevel(file_level)
        root_logger.addHandler(handler)
        # Iterate over a snapshot: logging.getLogger(name) below turns
        # PlaceHolder entries into real Logger objects, i.e. it writes into
        # loggerDict while this loop is running.
        for name, item in list(logging.root.manager.loggerDict.items()):
            if isinstance(item, logging.Logger):
                # NOTE(review): the same handler is also on the root logger and
                # propagate is forced True here, so records from these loggers
                # appear to reach it twice (directly and via the root), which
                # would duplicate lines in the log file. Confirm before changing.
                item.addHandler(handler)
                item.propagate = True
                item.debug(f"Added handler to logger: {name}")
            else:
                logging.getLogger(name).debug(f"Skipping non-logger: {name}")

    if redirect_stdio:
        # NOTE(review): each call with redirect_stdio=True adds another pair of
        # handlers. Neither has a level or filter, so every record presumably
        # goes to both streams. Confirm this is intended.
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(formatter)
        stderr_handler = logging.StreamHandler(sys.stderr)
        stderr_handler.setFormatter(formatter)
        root_logger.addHandler(stdout_handler)
        root_logger.addHandler(stderr_handler)
        root_logger.debug("Added stdout and stderr handlers to root logger")

    logger = logging.getLogger(logger_name)
    setup_logging_level(logging_level=logging_level, logger_name=logger_name)
    # Debugging to print all handlers
    logger.debug(f"Logger {logger_name} handlers: {logger.handlers}")
    logger.debug(f"Global handler: {handler}")
    return logger
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
    """Return the current event loop, or a brand-new one if none is available.

    ``asyncio.get_event_loop()`` is tried first; inside a coroutine it returns
    the running loop. If it raises ``RuntimeError`` saying there is no running
    or current loop, a fresh loop is created from the active event-loop policy
    and returned. Any other ``RuntimeError`` is re-raised.

    Note:
        The freshly created loop is not installed with
        ``asyncio.set_event_loop``, so a later call from the same thread may
        create yet another loop. Nothing here closes the returned loop.
    """
    try:
        current = asyncio.get_event_loop()
    except RuntimeError as err:
        reason = str(err)
        if (
            "no running event loop" not in reason
            and "no current event loop" not in reason
        ):
            raise err
        logging.warning("Cant not get running event loop, create new event loop now")
        fresh = asyncio.get_event_loop_policy().new_event_loop()
        return cast(asyncio.BaseEventLoop, fresh)
    assert current is not None
    return cast(asyncio.BaseEventLoop, current)
def logging_str_to_uvicorn_level(log_level_str):
    """Translate a Python logging level name into uvicorn's level name.

    Matching is case-insensitive. ``"NOTSET"`` and any unrecognised name map
    to ``"info"``.
    """
    normalized = log_level_str.upper()
    if normalized in ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"):
        return normalized.lower()
    return "info"
class EndpointFilter(logging.Filter):
    """Logging filter that drops records whose message mentions ``path``.

    Used to disable access logs for certain endpoints. Based on
    https://github.com/encode/starlette/issues/864#issuecomment-1254987630
    """

    def __init__(self, path: str, *args: Any, **kwargs: Any):
        # Remaining arguments (e.g. ``name``) go to logging.Filter unchanged.
        super().__init__(*args, **kwargs)
        self._path = path

    def filter(self, record: logging.LogRecord) -> bool:
        """Keep the record only if its formatted message lacks ``self._path``."""
        message = record.getMessage()
        return self._path not in message
def setup_http_service_logging(exclude_paths: Optional[List[str]] = None):
    """Quiet noisy HTTP-service logs.

    Adds one ``EndpointFilter`` per path to the ``uvicorn.access`` logger, so
    access-log lines mentioning those paths are dropped. It also sets the
    ``httpx`` logger's level to WARNING.

    Args:
        exclude_paths (List[str]): Paths whose access logs are disabled. When
            falsy, defaults to the controller heartbeat and health-check paths.
    """
    # Default: hide heartbeat / health-check noise.
    paths = exclude_paths or ["/api/controller/heartbeat", "/api/health"]

    access_logger = logging.getLogger("uvicorn.access")
    for endpoint in paths:
        access_logger.addFilter(EndpointFilter(path=endpoint))

    logging.getLogger("httpx").setLevel(logging.WARNING)