fix(core): Fix fschat and alembic log conflict (#919)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng 2023-12-12 09:14:40 +08:00 committed by GitHub
parent cbba50ab1b
commit f1f4f4adda
5 changed files with 64 additions and 101 deletions

View File

@@ -13,7 +13,7 @@ from fastapi import APIRouter, FastAPI
 from fastapi import Depends, HTTPException
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
 from pydantic import BaseSettings
@@ -30,10 +30,7 @@ from fastchat.protocol.openai_api_protocol import (
     ModelPermission,
     UsageInfo,
 )
-from fastchat.protocol.api_protocol import (
-    APIChatCompletionRequest,
-)
-from fastchat.serve.openai_api_server import create_error_response, check_requests
+from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
 from fastchat.constants import ErrorCode
 from dbgpt.component import BaseComponent, ComponentType, SystemApp
@@ -85,6 +82,68 @@ async def check_api_key(
         return None


+def create_error_response(code: int, message: str) -> JSONResponse:
+    """Copy from fastchat.serve.openai_api_server.create_error_response.
+
+    We can't use fastchat.serve.openai_api_server because it has too many dependencies.
+    """
+    return JSONResponse(
+        ErrorResponse(message=message, code=code).dict(), status_code=400
+    )
+
+
+def check_requests(request) -> Optional[JSONResponse]:
+    """Copy from fastchat.serve.openai_api_server.check_requests.
+
+    We can't use fastchat.serve.openai_api_server because it has too many dependencies.
+    """
+    # Check all params
+    if request.max_tokens is not None and request.max_tokens <= 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
+        )
+    if request.n is not None and request.n <= 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.n} is less than the minimum of 1 - 'n'",
+        )
+    if request.temperature is not None and request.temperature < 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
+        )
+    if request.temperature is not None and request.temperature > 2:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
+        )
+    if request.top_p is not None and request.top_p < 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
+        )
+    if request.top_p is not None and request.top_p > 1:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
+        )
+    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.top_k} is out of range. Either set top_k to -1 or >= 1.",
+        )
+    if request.stop is not None and (
+        not isinstance(request.stop, str) and not isinstance(request.stop, list)
+    ):
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.stop} is not valid under any of the given schemas - 'stop'",
+        )
+    return None
+
+
 class APIServer(BaseComponent):
     name = ComponentType.MODEL_API_SERVER
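For review context, a minimal sketch of how the two vendored helpers are intended to be called from a handler. The route path and handler body are illustrative assumptions, not the actual apiserver code; only check_requests and APIChatCompletionRequest come from this diff.

# Hypothetical handler showing the intended call pattern (illustrative only;
# check_requests is the function added in this diff).
from fastapi import FastAPI
from fastchat.protocol.api_protocol import APIChatCompletionRequest

app = FastAPI()


@app.post("/api/v1/chat/completions")  # assumed path, not taken from this diff
async def create_chat_completion(request: APIChatCompletionRequest):
    # Validate sampling parameters before any model work; on the first
    # violation an OpenAI-style JSON error (HTTP 400) is returned.
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    # ...normal completion flow would continue here...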

View File

@ -1,8 +1,5 @@
from .utils import ( from .utils import (
get_gpu_memory, get_gpu_memory,
StreamToLogger,
disable_torch_init,
pretty_print_semaphore,
server_error_msg, server_error_msg,
get_or_create_event_loop, get_or_create_event_loop,
) )

View File

@@ -111,58 +111,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
     return logger


-class StreamToLogger(object):
-    """
-    Fake file-like stream object that redirects writes to a logger instance.
-    """
-
-    def __init__(self, logger, log_level=logging.INFO):
-        self.terminal = sys.stdout
-        self.logger = logger
-        self.log_level = log_level
-        self.linebuf = ""
-
-    def __getattr__(self, attr):
-        return getattr(self.terminal, attr)
-
-    def write(self, buf):
-        temp_linebuf = self.linebuf + buf
-        self.linebuf = ""
-        for line in temp_linebuf.splitlines(True):
-            # From the io.TextIOWrapper docs:
-            #   On output, if newline is None, any '\n' characters written
-            #   are translated to the system default line separator.
-            # By default sys.stdout.write() expects '\n' newlines and then
-            # translates them so this is still cross platform.
-            if line[-1] == "\n":
-                encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
-                self.logger.log(self.log_level, encoded_message.rstrip())
-            else:
-                self.linebuf += line
-
-    def flush(self):
-        if self.linebuf != "":
-            encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
-            self.logger.log(self.log_level, encoded_message.rstrip())
-            self.linebuf = ""
-
-
-def disable_torch_init():
-    """
-    Disable the redundant torch default initialization to accelerate model creation.
-    """
-    import torch
-
-    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
-    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
-
-
-def pretty_print_semaphore(semaphore):
-    if semaphore is None:
-        return "None"
-    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
-
-
 def get_or_create_event_loop() -> asyncio.BaseEventLoop:
     loop = None
     try:
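For reviewers unfamiliar with the deleted class: StreamToLogger replaces sys.stdout with a logger-backed object, which is exactly the kind of global logging state that alembic's fileConfig call (removed in the files below) would later clobber. A simplified, self-contained sketch of the same pattern, not the deleted code itself:

# Standalone sketch of the stdout-to-logger pattern the removed
# StreamToLogger class implemented (simplified, not the deleted code).
import logging
import sys

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
logger = logging.getLogger("stdout")


class StreamToLogger:
    """Minimal file-like object that forwards complete lines to a logger."""

    def __init__(self, logger, log_level=logging.INFO):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def write(self, buf):
        self.linebuf += buf
        while "\n" in self.linebuf:
            line, self.linebuf = self.linebuf.split("\n", 1)
            self.logger.log(self.log_level, line)

    def flush(self):
        if self.linebuf:
            self.logger.log(self.log_level, self.linebuf)
            self.linebuf = ""


sys.stdout = StreamToLogger(logger)
print("hello")  # arrives as a log record on stderr: INFO stdout: hello
sys.stdout = sys.__stdout__  # restore the real stream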

View File

@@ -79,38 +79,3 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
 # ruff.type = exec
 # ruff.executable = %(here)s/.venv/bin/ruff
 # ruff.options = --fix REVISION_SCRIPT_FILENAME
-
-# Logging configuration
-[loggers]
-keys = root,sqlalchemy,alembic
-
-[handlers]
-keys = console
-
-[formatters]
-keys = generic
-
-[logger_root]
-level = WARN
-handlers = console
-qualname =
-
-[logger_sqlalchemy]
-level = WARN
-handlers =
-qualname = sqlalchemy.engine
-
-[logger_alembic]
-level = INFO
-handlers =
-qualname = alembic
-
-[handler_console]
-class = StreamHandler
-args = (sys.stderr,)
-level = NOTSET
-formatter = generic
-
-[formatter_generic]
-format = %(levelname)-5.5s [%(name)s] %(message)s
-datefmt = %H:%M:%S
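Why deleting this block matters: it was only consumed by the fileConfig call in env.py (next file), and logging.config.fileConfig rebuilds logging from the [loggers]/[handlers]/[formatters] sections while, by default, disabling every logger configured before it runs. A minimal sketch of that effect, assuming the removed sections are saved on disk as alembic.ini and using a "dbgpt" logger purely for illustration:

# Sketch of the conflict: fileConfig() silently disables loggers the
# application configured earlier (assumes alembic.ini exists on disk).
import logging
from logging.config import fileConfig

app_logger = logging.getLogger("dbgpt")
app_logger.addHandler(logging.StreamHandler())  # app-side configuration

fileConfig("alembic.ini")  # default: disable_existing_loggers=True

# "dbgpt" is not declared in [loggers], so fileConfig disabled it:
print(app_logger.disabled)  # True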

View File

@@ -1,5 +1,3 @@
-from logging.config import fileConfig
-
 from sqlalchemy import engine_from_config
 from sqlalchemy import pool

@@ -11,10 +9,6 @@ from dbgpt.storage.metadata.meta_data import Base, engine
 # access to the values within the .ini file in use.
 config = context.config

-# Interpret the config file for Python logging.
-# This line sets up loggers basically.
-if config.config_file_name is not None:
-    fileConfig(config.config_file_name)

 # add your model's MetaData object here
 # for 'autogenerate' support
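For completeness, a hypothetical alternative to deleting the call outright would be to keep alembic's ini-driven logging but opt out of clobbering existing loggers; disable_existing_loggers is a standard parameter of logging.config.fileConfig. A sketch only, not what this commit does:

# Hypothetical alternative (NOT used by this commit): keep alembic's
# ini-driven logging but leave previously configured loggers untouched.
# "config" is alembic's context.config, as in env.py above.
from logging.config import fileConfig

if config.config_file_name is not None:
    fileConfig(config.config_file_name, disable_existing_loggers=False)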