Files
DB-GPT/dbgpt/util/utils.py

174 lines
4.9 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import logging
import logging.handlers
from typing import Any, List
import os
import asyncio
from dbgpt.configs.model_config import LOGDIR
try:
    from termcolor import colored
except ImportError:
    # termcolor is an optional dependency; without it, "colouring" is a no-op.
    def colored(x, *args, **kwargs):
        """Return ``x`` unchanged, ignoring any colour/attribute arguments."""
        del args, kwargs  # accepted only so call sites keep the termcolor signature
        return x
# Generic error text; presumably surfaced to clients/UI when a request fails
# (its consumers live outside this module).
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

# Shared file handler for log output. ``_build_logger`` creates it lazily the
# first time a log file name is supplied and reuses it on later calls.
handler: "logging.Handler | None" = None
def _get_logging_level() -> str:
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
def setup_logging_level(logging_level=None, logger_name: str = None):
    """Set the level of a named logger, or configure the root logger.

    Args:
        logging_level: a level name (case-insensitive, e.g. ``"info"``) or a
            numeric level such as ``logging.INFO``. Falsy values (``None``,
            ``""``, ``0``) are replaced by ``_get_logging_level()``, i.e. the
            ``DBGPT_LOG_LEVEL`` environment variable or ``"INFO"``.
        logger_name: when given, only ``logging.getLogger(logger_name)`` has
            its level set; otherwise ``logging.basicConfig`` is called.

    Raises:
        ValueError: if a level name is not registered with :mod:`logging`.
    """
    if not logging_level:
        logging_level = _get_logging_level()
    if isinstance(logging_level, str):
        # ``Logger.setLevel`` / ``basicConfig`` accept registered level names
        # directly. This avoids ``logging.getLevelName``'s name->number lookup
        # (documented as a retained historical mistake), which turns an
        # unknown name into a misleading "Level XYZ" value.
        logging_level = logging_level.upper()
    if logger_name:
        logging.getLogger(logger_name).setLevel(logging_level)
    else:
        # NOTE: basicConfig does nothing (level included) when the root
        # logger already has handlers.
        logging.basicConfig(level=logging_level, encoding="utf-8")
def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
    """Build the logger ``logger_name`` and, if available, colourise it.

    Args:
        logger_name: name of the logger to configure.
        logging_level: level name or number; falsy values are replaced by
            ``_get_logging_level()`` (``DBGPT_LOG_LEVEL`` or ``"INFO"``).
        logger_filename: forwarded to ``_build_logger`` as the log file name.

    When the optional ``coloredlogs`` package can be imported,
    ``coloredlogs.install`` is called for the logger; an ``ImportError`` is
    silently ignored.
    """
    effective_level = logging_level or _get_logging_level()
    target = _build_logger(logger_name, effective_level, logger_filename)
    try:
        import coloredlogs

        coloredlogs.install(level=effective_level or "INFO", logger=target)
    except ImportError:
        # coloredlogs is optional; plain formatting is kept without it.
        pass
def get_gpu_memory(max_gpus=None):
    """Return an estimate of available memory (GiB) for each CUDA device.

    For each inspected device the value is ``total_memory`` minus
    ``torch.cuda.memory_allocated()``. Only tensor memory allocated by *this*
    process is subtracted; memory used by other processes is not considered.

    Args:
        max_gpus: optional upper bound on how many devices (from index 0) to
            inspect; ``None`` inspects every device reported by
            ``torch.cuda.device_count()``.

    Returns:
        list[float]: one value per inspected device, in device-index order.
    """
    import torch

    gib = 1024**3
    visible = torch.cuda.device_count()
    count = visible if max_gpus is None else min(max_gpus, visible)

    free_per_gpu = []
    for index in range(count):
        # Make ``index`` the current device so the argument-less
        # ``memory_allocated()`` below reports on it.
        with torch.cuda.device(index):
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            total_gib = props.total_memory / gib
            used_gib = torch.cuda.memory_allocated() / gib
            free_per_gpu.append(total_gib - used_gib)
    return free_per_gpu
def _propagates_to_any(logger, candidates) -> bool:
    """Return True if records logged on ``logger`` reach a logger in ``candidates``.

    Walks ``logger.parent`` upwards for as long as each visited logger has
    ``propagate`` enabled, mirroring how ``Logger.callHandlers`` hands records
    to ancestor loggers' handlers.
    """
    current = logger
    while current.propagate and current.parent is not None:
        current = current.parent
        if current in candidates:
            return True
    return False


def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
    """Configure shared logging state and return ``logging.getLogger(logger_name)``.

    1. If the root logger has no handlers, ``setup_logging_level`` (which calls
       ``logging.basicConfig``) creates one. The shared formatter is then set
       on the root logger's first handler.
    2. On the first call that supplies ``logger_filename``, a single
       module-wide daily ``TimedRotatingFileHandler`` (UTF-8, UTC) is created
       under ``LOGDIR``. It is attached so that every logger registered at that
       moment (including ``logger_name``) writes each record to the file
       exactly once. Loggers created later only reach the file if they
       propagate to one of those loggers.
    3. ``logging_level`` is applied to the named logger.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        logging_level: level name or number; falsy means "use
            ``_get_logging_level()``" (see ``setup_logging_level``).
        logger_filename: file name, relative to ``LOGDIR``, for the shared file
            handler; ignored once that handler exists.

    Returns:
        logging.Logger: the requested logger.
    """
    global handler
    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers (basicConfig installs the first one).
    if not logging.getLogger().handlers:
        setup_logging_level(logging_level=logging_level)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Register the requested logger *before* attaching the file handler, so
    # it (and any children that propagate to it) is covered as well.
    logger = logging.getLogger(logger_name)

    # Add a file handler for all loggers (created once per process).
    if handler is None and logger_filename:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        # Explicit UTF-8, consistent with the ``encoding="utf-8"`` passed to
        # basicConfig in setup_logging_level; otherwise the locale's default
        # encoding is used and non-ASCII messages may fail to encode.
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True, encoding="utf-8"
        )
        handler.setFormatter(formatter)
        # Snapshot the registry: other threads may register loggers meanwhile.
        registered = [
            item
            for item in list(logging.root.manager.loggerDict.values())
            if isinstance(item, logging.Logger)
        ]
        targets = set(registered)
        for item in registered:
            # Attaching the same handler to a logger and to an ancestor it
            # propagates to would write each of its records twice; attach only
            # where records would not already reach another attached copy.
            if not _propagates_to_any(item, targets):
                item.addHandler(handler)

    setup_logging_level(logging_level=logging_level, logger_name=logger_name)
    return logger
def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    """Return the current event loop, creating a new one if none is available.

    ``asyncio.get_event_loop()`` is tried first. If it raises ``RuntimeError``
    reporting that there is no running/current event loop (e.g. in a
    non-main thread), a fresh loop from ``asyncio.new_event_loop()`` is
    returned instead. Any other ``RuntimeError`` is re-raised.

    NOTE(review): the newly created loop is not installed with
    ``asyncio.set_event_loop``, so every such call creates another loop that
    the caller is responsible for closing.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError as e:
        message = str(e)
        if (
            "no running event loop" not in message
            and "no current event loop" not in message
        ):
            raise
    # A named logger instead of ``logging.warning``: the module-level helper
    # implicitly calls ``basicConfig()`` on an unconfigured root logger, which
    # would pre-empt the root configuration done by this module.
    logging.getLogger(__name__).warning(
        "Cannot get running event loop, create new event loop now"
    )
    # Equivalent to ``get_event_loop_policy().new_event_loop()`` without the
    # policy API deprecated in recent Python versions.
    return asyncio.new_event_loop()
def logging_str_to_uvicorn_level(log_level_str):
    """Translate a :mod:`logging` level name into a uvicorn ``log_level`` value.

    Matching is case-insensitive. ``NOTSET`` and any unrecognised name map
    to ``"info"``.
    """
    normalized = log_level_str.upper()
    if normalized in ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"):
        return normalized.lower()
    return "info"
class EndpointFilter(logging.Filter):
    """Logging filter that drops records whose message mentions a given path.

    Used here on the ``uvicorn.access`` logger to silence noisy endpoints.
    Adapted from
    https://github.com/encode/starlette/issues/864#issuecomment-1254987630
    """

    def __init__(self, path: str, *args: Any, **kwargs: Any):
        # Any remaining arguments (e.g. ``name``) belong to ``logging.Filter``.
        super().__init__(*args, **kwargs)
        self._path = path

    def filter(self, record: logging.LogRecord) -> bool:
        """Keep the record only if its formatted message lacks the path."""
        return self._path not in record.getMessage()
def setup_http_service_logging(exclude_paths: List[str] = None):
    """Setup http service logging.

    Now just disable some logs:

    * adds an ``EndpointFilter`` to the ``uvicorn.access`` logger for each
      path in ``exclude_paths``. Paths that already have such a filter are
      skipped, so repeated calls do not stack duplicate filters
      (``Logger.addFilter`` only dedups the *same* filter object);
    * sets the ``httpx`` logger's level to ``WARNING``.

    Args:
        exclude_paths (List[str]): The paths to disable log. A falsy value
            (``None`` or an empty list) means ``["/api/controller/heartbeat"]``.
    """
    if not exclude_paths:
        # Not show heartbeat log
        exclude_paths = ["/api/controller/heartbeat"]

    # ``logging.getLogger`` always returns a Logger, so no truthiness check.
    uvicorn_logger = logging.getLogger("uvicorn.access")
    already_filtered = {
        flt._path
        for flt in uvicorn_logger.filters
        if isinstance(flt, EndpointFilter)
    }
    for path in exclude_paths:
        if path not in already_filtered:
            uvicorn_logger.addFilter(EndpointFilter(path=path))
            already_filtered.add(path)

    logging.getLogger("httpx").setLevel(logging.WARNING)