mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-26 13:27:46 +00:00
207 lines
6.3 KiB
Python
207 lines
6.3 KiB
Python
import asyncio
|
|
import logging
|
|
import logging.handlers
|
|
import os
|
|
import sys
|
|
from typing import Any, List, Optional, cast
|
|
|
|
from dbgpt.configs.model_config import LOGDIR
|
|
|
|
# termcolor is optional: when it is missing, expose a stand-in with the same
# name so callers can keep calling ``colored(...)`` unconditionally.
try:
    from termcolor import colored
except ImportError:

    def colored(x, *args, **kwargs):
        """Return ``x`` untouched; every styling argument is ignored."""
        return x
|
|
|
|
|
|
# Canned error text; nothing in the visible part of this module reads it, so
# its consumers presumably live elsewhere in the project.
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

# Module-wide file handler. It stays ``None`` until ``_build_logger`` is first
# called with a ``logger_filename``; later calls reuse the same instance.
handler: Optional[logging.Handler] = None
|
|
|
|
|
|
def _get_logging_level() -> str:
    """Return the level name from ``DBGPT_LOG_LEVEL``, or ``"INFO"`` if unset."""
    return os.environ.get("DBGPT_LOG_LEVEL", "INFO")
|
|
|
|
|
|
def setup_logging_level(
    logging_level: Optional[str] = None, logger_name: Optional[str] = None
):
    """Set the level of a named logger, or configure the root logger.

    Args:
        logging_level (Optional[str]): Level name, matched case-insensitively
            (for example ``"debug"``). When empty or ``None`` the value from
            ``_get_logging_level()`` is used instead.
        logger_name (Optional[str]): Logger to update. When empty, the root
            logger is configured through ``logging.basicConfig``, which does
            nothing if the root logger already has handlers.

    Raises:
        ValueError: If the level name is not known to the ``logging`` module
            and a level is actually applied.
    """
    level = logging_level or _get_logging_level()
    # isinstance rather than an exact type check, so that str subclasses are
    # upper-cased and resolved as well instead of reaching setLevel() raw.
    if isinstance(level, str):
        level = logging.getLevelName(level.upper())
    if logger_name:
        logging.getLogger(logger_name).setLevel(level)
    else:
        logging.basicConfig(level=level, encoding="utf-8")
|
|
|
|
|
|
def setup_logging(
    logger_name: str,
    logging_level: Optional[str] = None,
    logger_filename: Optional[str] = None,
    redirect_stdio: bool = False,
):
    """Build the named logger and, when available, colorize its output.

    Args:
        logger_name (str): Name of the logger to configure.
        logging_level (Optional[str]): Level name; when empty the value from
            ``_get_logging_level()`` is used.
        logger_filename (Optional[str]): Forwarded to ``_build_logger``.
        redirect_stdio (bool): Forwarded to ``_build_logger``.

    Nothing is returned; the configured logger is not handed back.
    """
    effective_level = logging_level or _get_logging_level()
    target_logger = _build_logger(
        logger_name, effective_level, logger_filename, redirect_stdio
    )
    try:
        import coloredlogs

        coloredlogs.install(level=effective_level or "INFO", logger=target_logger)
    except ImportError:
        # coloredlogs is an optional extra; plain output is fine without it.
        pass
|
|
|
|
|
|
def get_gpu_memory(max_gpus=None):
    """List, per CUDA device, total memory minus torch-allocated memory in GiB.

    Args:
        max_gpus: Upper bound on the number of devices to inspect. ``None``
            inspects every device reported by ``torch.cuda.device_count()``.

    Returns:
        A list of floats, one per inspected device, each being the device's
        ``total_memory`` minus ``torch.cuda.memory_allocated()`` in GiB.
    """
    import torch

    bytes_per_gib = 1024**3
    visible = torch.cuda.device_count()
    limit = visible if max_gpus is None else min(max_gpus, visible)

    free_gib = []
    for index in range(limit):
        # Make ``index`` the current device so the queries below target it.
        with torch.cuda.device(index):
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            total = props.total_memory / bytes_per_gib
            in_use = torch.cuda.memory_allocated() / bytes_per_gib
            free_gib.append(total - in_use)
    return free_gib
|
|
|
|
|
|
def _build_logger(
    logger_name,
    logging_level: Optional[str] = None,
    logger_filename: Optional[str] = None,
    redirect_stdio: bool = False,
):
    """Configure logging and return the logger called ``logger_name``.

    Steps, in order:

    1. If the root logger has no handlers, ``setup_logging_level`` is called
       for the root logger; the first root handler then gets this module's
       formatter.
    2. On the first call that passes ``logger_filename``, a
       ``TimedRotatingFileHandler`` (``when="D"``, ``utc=True``) is created
       under ``LOGDIR``, stored in the module-level ``handler`` and attached
       to the root logger and to every ``Logger`` currently registered.
       Later calls reuse the stored handler and skip this step.
    3. With ``redirect_stdio``, stdout and stderr stream handlers are added
       to the root logger (on every such call).
    4. The level of ``logger_name`` is set through ``setup_logging_level``.

    Args:
        logger_name: Name of the logger to return.
        logging_level (Optional[str]): Level name, matched case-insensitively.
        logger_filename (Optional[str]): File name, relative to ``LOGDIR``.
        redirect_stdio (bool): Whether to add stdout/stderr handlers to root.

    Returns:
        logging.Logger: The logger registered under ``logger_name``.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    root_logger = logging.getLogger()

    # Set the format of root handlers
    if not root_logger.handlers:
        setup_logging_level(logging_level=logging_level)
    root_logger.handlers[0].setFormatter(formatter)

    # Add a file handler for all loggers
    if handler is None and logger_filename:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True
        )
        handler.setFormatter(formatter)

        # Ensure the handler level is set correctly
        if logging_level is not None:
            # Handler.setLevel() only knows upper-case level names, while the
            # rest of this module accepts any case (e.g. DBGPT_LOG_LEVEL=debug),
            # so normalise here to avoid "ValueError: Unknown level".
            handler_level = logging_level
            if isinstance(handler_level, str):
                handler_level = handler_level.upper()
            handler.setLevel(handler_level)
        root_logger.addHandler(handler)
        # NOTE(review): the handler is attached to root *and* to each logger
        # with propagate=True, so one record can reach it more than once.
        # Iterate over a snapshot: getLogger() below may touch the registry.
        for name, item in list(logging.root.manager.loggerDict.items()):
            if isinstance(item, logging.Logger):
                item.addHandler(handler)
                item.propagate = True
                logging.getLogger(name).debug(f"Added handler to logger: {name}")
            else:
                logging.getLogger(name).debug(f"Skipping non-logger: {name}")

    if redirect_stdio:
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(formatter)
        stderr_handler = logging.StreamHandler(sys.stderr)
        stderr_handler.setFormatter(formatter)

        root_logger.addHandler(stdout_handler)
        root_logger.addHandler(stderr_handler)
        root_logger.debug("Added stdout and stderr handlers to root logger")
    logger = logging.getLogger(logger_name)

    setup_logging_level(logging_level=logging_level, logger_name=logger_name)

    # Debugging to print all handlers
    logger.debug(f"Logger {logger_name} handlers: {logger.handlers}")
    logger.debug(f"Global handler: {handler}")

    return logger
|
|
|
|
|
|
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
    """Return the current asyncio event loop, or a new one if none exists.

    ``asyncio.get_event_loop()`` is tried first. If it raises a
    ``RuntimeError`` whose message mentions "no running event loop" or
    "no current event loop", a warning is logged and a freshly created loop
    is returned. The new loop is not installed as the current loop.

    Raises:
        RuntimeError: Any other ``RuntimeError`` from ``get_event_loop()``.
    """
    try:
        current = asyncio.get_event_loop()
    except RuntimeError as err:
        reason = str(err)
        missing_loop = (
            "no running event loop" in reason or "no current event loop" in reason
        )
        if not missing_loop:
            raise
        logging.warning("Cant not get running event loop, create new event loop now")
        return cast(asyncio.BaseEventLoop, asyncio.new_event_loop())
    assert current is not None
    return cast(asyncio.BaseEventLoop, current)
|
|
|
|
|
|
def logging_str_to_uvicorn_level(log_level_str):
    """Translate a ``logging`` level name into uvicorn's lower-case spelling.

    Args:
        log_level_str: Level name in any letter case.

    Returns:
        str: ``"critical"``, ``"error"``, ``"warning"``, ``"info"`` or
        ``"debug"``. ``NOTSET`` and any unrecognised name map to ``"info"``.
    """
    passthrough = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
    normalized = log_level_str.upper()
    if normalized in passthrough:
        return normalized.lower()
    return "info"
|
|
|
|
|
|
class EndpointFilter(logging.Filter):
    """Drop log records whose message mentions a given path.

    Used to silence access-log lines for selected endpoints.

    source: https://github.com/encode/starlette/issues/864#issuecomment-1254987630
    """

    def __init__(self, path: str, *args: Any, **kwargs: Any):
        """Remember ``path``; remaining arguments go to ``logging.Filter``."""
        super().__init__(*args, **kwargs)
        self._path = path

    def filter(self, record: logging.LogRecord) -> bool:
        """Return ``False`` (suppress) when the message contains the path."""
        message = record.getMessage()
        return self._path not in message
|
|
|
|
|
|
def setup_http_service_logging(exclude_paths: Optional[List[str]] = None):
    """Setup http service logging

    Currently this only silences some noisy logs: access-log records for the
    excluded paths are filtered out of ``uvicorn.access``, and the ``httpx``
    logger is raised to ``WARNING``.

    Args:
        exclude_paths (List[str]): The paths to disable log. When empty or
            ``None``, the heartbeat and health endpoints are excluded.
    """
    # Not show heartbeat log
    muted_paths = exclude_paths or ["/api/controller/heartbeat", "/api/health"]

    access_log = logging.getLogger("uvicorn.access")
    if access_log:
        for muted in muted_paths:
            access_log.addFilter(EndpointFilter(path=muted))

    client_log = logging.getLogger("httpx")
    if client_log:
        client_log.setLevel(logging.WARNING)
|