mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 19:11:52 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -9,14 +9,18 @@ import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
import shortuuid
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastchat.constants import ErrorCode
|
||||
from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
|
||||
from fastchat.protocol.openai_api_protocol import (
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, model_to_dict, model_to_json
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core.interface.message import ModelMessage
|
||||
from dbgpt.core.schema.api import (
|
||||
APIChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
@@ -25,20 +29,18 @@ from fastchat.protocol.openai_api_protocol import (
|
||||
DeltaMessage,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
ErrorCode,
|
||||
ErrorResponse,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelPermission,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core.interface.message import ModelMessage
|
||||
from dbgpt.model.base import ModelInstance
|
||||
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
|
||||
from dbgpt.model.cluster.registry import ModelRegistry
|
||||
from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
|
||||
from dbgpt.util.fastapi import create_app
|
||||
from dbgpt.util.parameter_utils import EnvArgumentParser
|
||||
from dbgpt.util.utils import setup_logging
|
||||
|
||||
@@ -88,7 +90,7 @@ def create_error_response(code: int, message: str) -> JSONResponse:
|
||||
We can't use fastchat.serve.openai_api_server because it has too many dependencies.
|
||||
"""
|
||||
return JSONResponse(
|
||||
ErrorResponse(message=message, code=code).dict(), status_code=400
|
||||
model_to_dict(ErrorResponse(message=message, code=code)), status_code=400
|
||||
)
|
||||
|
||||
|
||||
@@ -266,7 +268,8 @@ class APIServer(BaseComponent):
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=id, choices=[choice_data], model=model_name
|
||||
)
|
||||
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {json_data}\n\n"
|
||||
|
||||
previous_text = ""
|
||||
async for model_output in worker_manager.generate_stream(params):
|
||||
@@ -297,10 +300,15 @@ class APIServer(BaseComponent):
|
||||
if model_output.finish_reason is not None:
|
||||
finish_stream_events.append(chunk)
|
||||
continue
|
||||
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {json_data}\n\n"
|
||||
|
||||
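# There is no "content" field in the last delta message, so use exclude_none to drop the "content" field.
# NOTE(review): the replacement call below passes exclude_unset=True rather than exclude_none=True — confirm which is intended.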
|
||||
for finish_chunk in finish_stream_events:
|
||||
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
|
||||
json_data = model_to_json(
|
||||
finish_chunk, exclude_unset=True, ensure_ascii=False
|
||||
)
|
||||
yield f"data: {json_data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def chat_completion_generate(
|
||||
@@ -335,8 +343,8 @@ class APIServer(BaseComponent):
|
||||
)
|
||||
)
|
||||
if model_output.usage:
|
||||
task_usage = UsageInfo.parse_obj(model_output.usage)
|
||||
for usage_key, usage_value in task_usage.dict().items():
|
||||
task_usage = UsageInfo.model_validate(model_output.usage)
|
||||
for usage_key, usage_value in model_to_dict(task_usage).items():
|
||||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||
|
||||
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
|
||||
@@ -442,8 +450,9 @@ async def create_embeddings(
|
||||
}
|
||||
for i, emb in enumerate(embeddings)
|
||||
]
|
||||
return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
|
||||
exclude_none=True
|
||||
return model_to_dict(
|
||||
EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()),
|
||||
exclude_none=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -492,7 +501,7 @@ def initialize_apiserver(
|
||||
embedded_mod = True
|
||||
if not app:
|
||||
embedded_mod = False
|
||||
app = FastAPI()
|
||||
app = create_app()
|
||||
|
||||
if not system_app:
|
||||
system_app = SystemApp(app)
|
||||
|
@@ -22,9 +22,10 @@ from dbgpt.model.cluster.apiserver.api import (
|
||||
)
|
||||
from dbgpt.model.cluster.tests.conftest import _new_cluster
|
||||
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
|
||||
from dbgpt.util.fastapi import create_app
|
||||
from dbgpt.util.openai_utils import chat_completion, chat_completion_stream
|
||||
|
||||
app = FastAPI()
|
||||
app = create_app()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -69,7 +70,7 @@ async def client(request, system_app: SystemApp):
|
||||
async def test_get_all_models(client: AsyncClient):
|
||||
res = await client.get("/api/v1/models")
|
||||
res.status_code == 200
|
||||
model_lists = ModelList.parse_obj(res.json())
|
||||
model_lists = ModelList.model_validate(res.json())
|
||||
print(f"model list json: {res.json()}")
|
||||
assert model_lists.object == "list"
|
||||
assert len(model_lists.data) == 2
|
||||
|
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi import APIRouter
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.model.base import ModelInstance
|
||||
@@ -10,6 +10,7 @@ from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
from dbgpt.model.parameter import ModelControllerParameters
|
||||
from dbgpt.util.api_utils import _api_remote as api_remote
|
||||
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
|
||||
from dbgpt.util.fastapi import create_app
|
||||
from dbgpt.util.parameter_utils import EnvArgumentParser
|
||||
from dbgpt.util.utils import setup_http_service_logging, setup_logging
|
||||
|
||||
@@ -152,7 +153,7 @@ def initialize_controller(
|
||||
import uvicorn
|
||||
|
||||
setup_http_service_logging()
|
||||
app = FastAPI()
|
||||
app = create_app()
|
||||
app.include_router(router, prefix="/api", tags=["Model"])
|
||||
uvicorn.run(app, host=host, port=port, log_level="info")
|
||||
|
||||
|
@@ -11,7 +11,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict
|
||||
from typing import Awaitable, Callable, Iterator
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
@@ -28,6 +28,7 @@ from dbgpt.model.cluster.registry import ModelRegistry
|
||||
from dbgpt.model.cluster.worker_base import ModelWorker
|
||||
from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
|
||||
from dbgpt.model.utils.llm_utils import list_supported_models
|
||||
from dbgpt.util.fastapi import create_app, register_event_handler
|
||||
from dbgpt.util.parameter_utils import (
|
||||
EnvArgumentParser,
|
||||
ParameterDescription,
|
||||
@@ -829,7 +830,7 @@ def _setup_fastapi(
|
||||
worker_params: ModelWorkerParameters, app=None, ignore_exception: bool = False
|
||||
):
|
||||
if not app:
|
||||
app = FastAPI()
|
||||
app = create_app()
|
||||
setup_http_service_logging()
|
||||
|
||||
if worker_params.standalone:
|
||||
@@ -850,7 +851,6 @@ def _setup_fastapi(
|
||||
initialize_controller(app=app)
|
||||
app.include_router(controller_router, prefix="/api")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
async def start_worker_manager():
|
||||
try:
|
||||
@@ -865,10 +865,11 @@ def _setup_fastapi(
|
||||
# the fastapi app (registered to the controller)
|
||||
asyncio.create_task(start_worker_manager())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def startup_event():
|
||||
async def shutdown_event():
|
||||
await worker_manager.stop(ignore_exception=ignore_exception)
|
||||
|
||||
register_event_handler(app, "startup", startup_event)
|
||||
register_event_handler(app, "shutdown", shutdown_event)
|
||||
return app
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user