fix(core): Fix fschat and alembic log conflict (#919)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng 2023-12-12 09:14:40 +08:00 committed by GitHub
parent cbba50ab1b
commit f1f4f4adda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 101 deletions

View File

@@ -13,7 +13,7 @@ from fastapi import APIRouter, FastAPI
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseSettings
@@ -30,10 +30,7 @@ from fastchat.protocol.openai_api_protocol import (
ModelPermission,
UsageInfo,
)
from fastchat.protocol.api_protocol import (
APIChatCompletionRequest,
)
from fastchat.serve.openai_api_server import create_error_response, check_requests
from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
from fastchat.constants import ErrorCode
from dbgpt.component import BaseComponent, ComponentType, SystemApp
@@ -85,6 +82,68 @@ async def check_api_key(
return None
def create_error_response(code: int, message: str) -> JSONResponse:
    """Build an OpenAI-API-style error response.

    Copy from ``fastchat.serve.openai_api_server.create_error_response``.
    We can't use fastchat.serve.openai_api_server because it has too many
    dependencies.

    Args:
        code: Application-level error code (see ``fastchat.constants.ErrorCode``).
        message: Human-readable description of the error.

    Returns:
        A ``JSONResponse`` with HTTP status 400 wrapping an ``ErrorResponse`` body.
    """
    # NOTE: the original docstring here referenced check_requests — the two
    # functions' docstrings were swapped; fixed to name the right upstream source.
    return JSONResponse(
        ErrorResponse(message=message, code=code).dict(), status_code=400
    )
def check_requests(request) -> Optional[JSONResponse]:
    """Validate the sampling parameters of a chat completion request.

    Copy from ``fastchat.serve.openai_api_server.check_requests``.
    We can't use fastchat.serve.openai_api_server because it has too many
    dependencies.

    Args:
        request: An ``APIChatCompletionRequest``-like object with optional
            ``max_tokens``, ``n``, ``temperature``, ``top_p``, ``top_k``
            and ``stop`` attributes.

    Returns:
        A 400 ``JSONResponse`` describing the first out-of-range parameter,
        or ``None`` when every parameter is valid.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            # Fixed: the original (and upstream fastchat) message said
            # 'temperature' here even though top_p is being checked.
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
        )
    if request.stop is not None and (
        not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )
    return None
class APIServer(BaseComponent):
name = ComponentType.MODEL_API_SERVER

View File

@@ -1,8 +1,5 @@
from .utils import (
get_gpu_memory,
StreamToLogger,
disable_torch_init,
pretty_print_semaphore,
server_error_msg,
get_or_create_event_loop,
)

View File

@@ -111,58 +111,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
return logger
class StreamToLogger(object):
    """Fake file-like stream object that redirects writes to a logger instance."""

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        # Any file-object attribute we don't define is delegated to stdout,
        # so this object can stand in for sys.stdout transparently.
        return getattr(self.terminal, attr)

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        for piece in pending.splitlines(True):
            # From the io.TextIOWrapper docs:
            # On output, if newline is None, any '\n' characters written
            # are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if piece.endswith("\n"):
                sanitized = piece.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, sanitized.rstrip())
            else:
                # Incomplete line: keep it buffered until its newline arrives.
                self.linebuf += piece

    def flush(self):
        # Emit whatever partial line is still buffered, then reset the buffer.
        if self.linebuf:
            sanitized = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, sanitized.rstrip())
        self.linebuf = ""
def disable_torch_init():
    """Disable the redundant torch default initialization to accelerate model creation."""
    import torch

    # Replace the parameter-reset hooks with no-ops so constructing
    # Linear/LayerNorm modules skips their default weight initialization
    # (the weights get overwritten by the checkpoint load anyway).
    for module_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        module_cls.reset_parameters = lambda self: None
def pretty_print_semaphore(semaphore):
    """Return a short human-readable description of *semaphore*.

    Shows the internal counter and lock state, or the string ``"None"``
    when no semaphore is given.
    """
    if semaphore is None:
        return "None"
    value = semaphore._value  # private counter; matches upstream fastchat usage
    locked = semaphore.locked()
    return f"Semaphore(value={value}, locked={locked})"
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
loop = None
try:

View File

@@ -79,38 +79,3 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -1,5 +1,3 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
@@ -11,10 +9,6 @@ from dbgpt.storage.metadata.meta_data import Base, engine
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support