mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 02:46:40 +00:00
fix(core): Fix fschat and alembic log conflict (#919)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
parent
cbba50ab1b
commit
f1f4f4adda
@ -13,7 +13,7 @@ from fastapi import APIRouter, FastAPI
|
|||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
@ -30,10 +30,7 @@ from fastchat.protocol.openai_api_protocol import (
|
|||||||
ModelPermission,
|
ModelPermission,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from fastchat.protocol.api_protocol import (
|
from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
|
||||||
APIChatCompletionRequest,
|
|
||||||
)
|
|
||||||
from fastchat.serve.openai_api_server import create_error_response, check_requests
|
|
||||||
from fastchat.constants import ErrorCode
|
from fastchat.constants import ErrorCode
|
||||||
|
|
||||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||||
@ -85,6 +82,68 @@ async def check_api_key(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_error_response(code: int, message: str) -> JSONResponse:
    """Build the JSON error reply used when a request is rejected.

    Copy of fastchat.serve.openai_api_server.create_error_response.

    We can't use fastchat.serve.openai_api_server because it has too many
    dependencies.

    Args:
        code: Error code placed in the response body.
        message: Error description placed in the response body.

    Returns:
        A JSONResponse whose content is the ErrorResponse converted to a dict.
        The HTTP status code is always 400, independent of ``code``.
    """
    error_body = ErrorResponse(message=message, code=code).dict()
    return JSONResponse(error_body, status_code=400)
|
||||||
|
|
||||||
|
|
||||||
|
def check_requests(request) -> Optional[JSONResponse]:
    """Validate the sampling parameters of an incoming request.

    Copy of fastchat.serve.openai_api_server.check_requests.

    We can't use fastchat.serve.openai_api_server because it has too many
    dependencies.

    The checks run in a fixed order and stop at the first violation, so at
    most one error is reported per request.

    Args:
        request: Object exposing the attributes ``max_tokens``, ``n``,
            ``temperature``, ``top_p``, ``top_k`` and ``stop``. An attribute
            set to ``None`` is not checked.

    Returns:
        The response built by ``create_error_response`` with
        ``ErrorCode.PARAM_OUT_OF_RANGE`` for the first invalid parameter, or
        ``None`` when every check passes.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        # The message must name 'top_p': this branch rejects top_p, not
        # temperature.
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    # top_k accepts -1 or any value >= 1; only the open interval (-1, 1) is
    # rejected here.
    if request.top_k is not None and -1 < request.top_k < 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
        )
    if request.stop is not None and not isinstance(request.stop, (str, list)):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )

    return None
|
||||||
|
|
||||||
|
|
||||||
class APIServer(BaseComponent):
|
class APIServer(BaseComponent):
|
||||||
name = ComponentType.MODEL_API_SERVER
|
name = ComponentType.MODEL_API_SERVER
|
||||||
|
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
get_gpu_memory,
|
get_gpu_memory,
|
||||||
StreamToLogger,
|
|
||||||
disable_torch_init,
|
|
||||||
pretty_print_semaphore,
|
|
||||||
server_error_msg,
|
server_error_msg,
|
||||||
get_or_create_event_loop,
|
get_or_create_event_loop,
|
||||||
)
|
)
|
||||||
|
@ -111,58 +111,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
|
|||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
class StreamToLogger(object):
    """
    File-like object that forwards everything written to it to a logger.

    Text is buffered until a line ending in "\\n" is available; each such line
    is logged as one record at ``log_level``. Attributes not defined here are
    looked up on the ``sys.stdout`` object captured at construction time.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        # Holds the trailing text that has not yet been terminated by "\n".
        self.linebuf = ""

    def __getattr__(self, attr):
        # Only reached for attributes missing on this object; delegate them
        # to the captured stdout.
        return getattr(self.terminal, attr)

    def _emit(self, text):
        """Log ``text`` after dropping unencodable characters and trailing whitespace."""
        sanitized = text.encode("utf-8", "ignore").decode("utf-8")
        self.logger.log(self.log_level, sanitized.rstrip())

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        # From the io.TextIOWrapper docs:
        #   On output, if newline is None, any '\n' characters written
        #   are translated to the system default line separator.
        # By default sys.stdout.write() expects '\n' newlines and then
        # translates them so this is still cross platform.
        for chunk in pending.splitlines(True):
            if chunk.endswith("\n"):
                self._emit(chunk)
                continue
            # Not newline-terminated: keep it for a later write() or flush().
            self.linebuf += chunk

    def flush(self):
        if self.linebuf:
            self._emit(self.linebuf)
        self.linebuf = ""
|
|
||||||
|
|
||||||
|
|
||||||
def disable_torch_init():
    """
    Skip torch's default parameter initialization to accelerate model creation.

    Replaces ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with a function that does nothing. The classes
    themselves are patched, so the change applies process-wide.
    """
    import torch

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
|
|
||||||
|
|
||||||
|
|
||||||
def pretty_print_semaphore(semaphore):
    """Return a short human-readable description of ``semaphore``.

    Args:
        semaphore: ``None``, or an object with a ``_value`` attribute and a
            ``locked()`` method.

    Returns:
        The string ``"None"`` when ``semaphore`` is ``None``; otherwise
        ``"Semaphore(value=..., locked=...)"`` built from ``_value`` and
        ``locked()``.
    """
    if semaphore is not None:
        current_value = semaphore._value
        is_locked = semaphore.locked()
        return f"Semaphore(value={current_value}, locked={is_locked})"
    return "None"
|
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
|
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
|
||||||
loop = None
|
loop = None
|
||||||
try:
|
try:
|
||||||
|
@ -79,38 +79,3 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
|||||||
# ruff.type = exec
|
# ruff.type = exec
|
||||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
# Logging configuration
|
|
||||||
[loggers]
|
|
||||||
keys = root,sqlalchemy,alembic
|
|
||||||
|
|
||||||
[handlers]
|
|
||||||
keys = console
|
|
||||||
|
|
||||||
[formatters]
|
|
||||||
keys = generic
|
|
||||||
|
|
||||||
[logger_root]
|
|
||||||
level = WARN
|
|
||||||
handlers = console
|
|
||||||
qualname =
|
|
||||||
|
|
||||||
[logger_sqlalchemy]
|
|
||||||
level = WARN
|
|
||||||
handlers =
|
|
||||||
qualname = sqlalchemy.engine
|
|
||||||
|
|
||||||
[logger_alembic]
|
|
||||||
level = INFO
|
|
||||||
handlers =
|
|
||||||
qualname = alembic
|
|
||||||
|
|
||||||
[handler_console]
|
|
||||||
class = StreamHandler
|
|
||||||
args = (sys.stderr,)
|
|
||||||
level = NOTSET
|
|
||||||
formatter = generic
|
|
||||||
|
|
||||||
[formatter_generic]
|
|
||||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
|
||||||
datefmt = %H:%M:%S
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from logging.config import fileConfig
|
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
from sqlalchemy import engine_from_config
|
||||||
from sqlalchemy import pool
|
from sqlalchemy import pool
|
||||||
|
|
||||||
@ -11,10 +9,6 @@ from dbgpt.storage.metadata.meta_data import Base, engine
|
|||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
|
||||||
# This line sets up loggers basically.
|
|
||||||
if config.config_file_name is not None:
|
|
||||||
fileConfig(config.config_file_name)
|
|
||||||
|
|
||||||
# add your model's MetaData object here
|
# add your model's MetaData object here
|
||||||
# for 'autogenerate' support
|
# for 'autogenerate' support
|
||||||
|
Loading…
Reference in New Issue
Block a user