diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py index 0dd87543c..120f473ea 100644 --- a/dbgpt/util/utils.py +++ b/dbgpt/util/utils.py @@ -1,10 +1,8 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- - import asyncio import logging import logging.handlers import os +import sys from typing import Any, List, Optional, cast from dbgpt.configs.model_config import LOGDIR @@ -46,10 +44,11 @@ def setup_logging( logger_name: str, logging_level: Optional[str] = None, logger_filename: Optional[str] = None, + redirect_stdio: bool = False, ): if not logging_level: logging_level = _get_logging_level() - logger = _build_logger(logger_name, logging_level, logger_filename) + logger = _build_logger(logger_name, logging_level, logger_filename, redirect_stdio) try: import coloredlogs @@ -68,7 +67,6 @@ def get_gpu_memory(max_gpus=None): if max_gpus is None else min(max_gpus, torch.cuda.device_count()) ) - for gpu_id in range(num_gpus): with torch.cuda.device(gpu_id): device = torch.cuda.current_device() @@ -84,6 +82,7 @@ def _build_logger( logger_name, logging_level: Optional[str] = None, logger_filename: Optional[str] = None, + redirect_stdio: bool = False, ): global handler @@ -106,13 +105,38 @@ def _build_logger( ) handler.setFormatter(formatter) + # Ensure the handler level is set correctly + if logging_level is not None: + handler.setLevel(logging_level) + logging.getLogger().addHandler(handler) for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) - # Get logger + item.propagate = True + logging.getLogger(name).debug(f"Added handler to logger: {name}") + else: + logging.getLogger(name).debug(f"Skipping non-logger: {name}") + + if redirect_stdio: + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(formatter) + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setFormatter(formatter) + + root_logger = logging.getLogger() + root_logger.addHandler(stdout_handler) + root_logger.addHandler(stderr_handler) + logging.getLogger().debug("Added stdout and stderr handlers to root logger") logger = logging.getLogger(logger_name) + setup_logging_level(logging_level=logging_level, logger_name=logger_name) + # Debugging to print all handlers + logging.getLogger(logger_name).debug( + f"Logger {logger_name} handlers: {logger.handlers}" + ) + logging.getLogger(logger_name).debug(f"Global handler: {handler}") + return logger