mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 07:34:07 +00:00
fix(core): Fix fschat and alembic log conflict (#919)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
parent
cbba50ab1b
commit
f1f4f4adda
@ -13,7 +13,7 @@ from fastapi import APIRouter, FastAPI
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from pydantic import BaseSettings
|
||||
@ -30,10 +30,7 @@ from fastchat.protocol.openai_api_protocol import (
|
||||
ModelPermission,
|
||||
UsageInfo,
|
||||
)
|
||||
from fastchat.protocol.api_protocol import (
|
||||
APIChatCompletionRequest,
|
||||
)
|
||||
from fastchat.serve.openai_api_server import create_error_response, check_requests
|
||||
from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
|
||||
from fastchat.constants import ErrorCode
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
@ -85,6 +82,68 @@ async def check_api_key(
|
||||
return None
|
||||
|
||||
|
||||
def create_error_response(code: int, message: str) -> JSONResponse:
    """Build an error response in the OpenAI-compatible API error format.

    Copy of fastchat.serve.openai_api_server.create_error_response.

    We can't use fastchat.serve.openai_api_server because it has too many dependencies.

    Args:
        code: A fastchat ErrorCode value carried inside the response body.
        message: Human-readable description of the error.

    Returns:
        A JSONResponse with HTTP status 400; the specific error code is
        conveyed in the body, not in the HTTP status.
    """
    return JSONResponse(
        ErrorResponse(message=message, code=code).dict(), status_code=400
    )
|
||||
|
||||
|
||||
def check_requests(request) -> Optional[JSONResponse]:
    """Validate the sampling parameters of a chat-completion request.

    Copy of fastchat.serve.openai_api_server.check_requests.

    We can't use fastchat.serve.openai_api_server because it has too many dependencies.

    Args:
        request: A request object exposing max_tokens, n, temperature,
            top_p, top_k and stop attributes (any may be None).

    Returns:
        An error JSONResponse describing the first out-of-range parameter,
        or None when all parameters are valid.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        # Fixed: the original message said 'temperature' for a top_p error.
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    # top_k == -1 conventionally means "disabled"; otherwise it must be >= 1.
    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
        )
    if request.stop is not None and (
        not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )

    return None
|
||||
|
||||
|
||||
class APIServer(BaseComponent):
|
||||
name = ComponentType.MODEL_API_SERVER
|
||||
|
||||
|
@ -1,8 +1,5 @@
|
||||
from .utils import (
|
||||
get_gpu_memory,
|
||||
StreamToLogger,
|
||||
disable_torch_init,
|
||||
pretty_print_semaphore,
|
||||
server_error_msg,
|
||||
get_or_create_event_loop,
|
||||
)
|
||||
|
@ -111,58 +111,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
|
||||
return logger
|
||||
|
||||
|
||||
class StreamToLogger(object):
    """File-like stream that forwards complete lines to a logger.

    Intended as a drop-in replacement for sys.stdout/sys.stderr so stray
    print output is routed into the logging system. Attributes this class
    does not define are delegated to the real sys.stdout.
    """

    def __init__(self, logger, log_level=logging.INFO):
        # Fallback target for attribute delegation (isatty, fileno, ...).
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        # Holds the trailing partial line carried over between write() calls.
        self.linebuf = ""

    def __getattr__(self, attr):
        # Delegate anything we don't implement to the real stdout.
        return getattr(self.terminal, attr)

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        # splitlines(True) keeps line endings, letting us distinguish a
        # complete line from a trailing fragment that must be buffered.
        for piece in pending.splitlines(True):
            # From the io.TextIOWrapper docs:
            # On output, if newline is None, any '\n' characters written
            # are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if piece.endswith("\n"):
                cleaned = piece.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, cleaned.rstrip())
            else:
                self.linebuf = self.linebuf + piece

    def flush(self):
        # Emit whatever partial line is still buffered, then reset.
        if self.linebuf:
            cleaned = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, cleaned.rstrip())
        self.linebuf = ""
|
||||
|
||||
|
||||
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    # Imported lazily so merely importing this module does not require torch.
    import torch

    # Replace the default weight-init hooks with no-ops; weights will be
    # overwritten by the checkpoint load anyway.
    torch.nn.Linear.reset_parameters = lambda self: None
    torch.nn.LayerNorm.reset_parameters = lambda self: None
|
||||
|
||||
|
||||
def pretty_print_semaphore(semaphore):
    """Return a short human-readable description of an asyncio semaphore.

    Returns the string "None" when no semaphore is given.
    """
    if semaphore is None:
        return "None"
    # _value is the semaphore's internal count of remaining permits.
    value = semaphore._value
    locked = semaphore.locked()
    return f"Semaphore(value={value}, locked={locked})"
|
||||
|
||||
|
||||
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
|
||||
loop = None
|
||||
try:
|
||||
|
@ -79,38 +79,3 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
|
@ -1,5 +1,3 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
@ -11,10 +9,6 @@ from dbgpt.storage.metadata.meta_data import Base, engine
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
|
Loading…
Reference in New Issue
Block a user