feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -9,14 +9,18 @@ import logging
from typing import Any, Dict, Generator, List, Optional
import shortuuid
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from fastchat.constants import ErrorCode
from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
from fastchat.protocol.openai_api_protocol import (
from dbgpt._private.pydantic import BaseModel, model_to_dict, model_to_json
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.core.schema.api import (
APIChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
@@ -25,20 +29,18 @@ from fastchat.protocol.openai_api_protocol import (
DeltaMessage,
EmbeddingsRequest,
EmbeddingsResponse,
ErrorCode,
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
from dbgpt._private.pydantic import BaseModel
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
from dbgpt.util.fastapi import create_app
from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.util.utils import setup_logging
@@ -88,7 +90,7 @@ def create_error_response(code: int, message: str) -> JSONResponse:
We can't use fastchat.serve.openai_api_server because it has too many dependencies.
"""
return JSONResponse(
ErrorResponse(message=message, code=code).dict(), status_code=400
model_to_dict(ErrorResponse(message=message, code=code)), status_code=400
)
@@ -266,7 +268,8 @@ class APIServer(BaseComponent):
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
previous_text = ""
async for model_output in worker_manager.generate_stream(params):
@@ -297,10 +300,15 @@ class APIServer(BaseComponent):
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
# There is no "content" field in the last delta message, so serialize with exclude_none to drop the null "content" field.
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
json_data = model_to_json(
finish_chunk, exclude_none=True, ensure_ascii=False
)
yield f"data: {json_data}\n\n"
yield "data: [DONE]\n\n"
async def chat_completion_generate(
@@ -335,8 +343,8 @@ class APIServer(BaseComponent):
)
)
if model_output.usage:
task_usage = UsageInfo.parse_obj(model_output.usage)
for usage_key, usage_value in task_usage.dict().items():
task_usage = UsageInfo.model_validate(model_output.usage)
for usage_key, usage_value in model_to_dict(task_usage).items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
@@ -442,8 +450,9 @@ async def create_embeddings(
}
for i, emb in enumerate(embeddings)
]
return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
exclude_none=True
return model_to_dict(
EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()),
exclude_none=True,
)
@@ -492,7 +501,7 @@ def initialize_apiserver(
embedded_mod = True
if not app:
embedded_mod = False
app = FastAPI()
app = create_app()
if not system_app:
system_app = SystemApp(app)

View File

@@ -22,9 +22,10 @@ from dbgpt.model.cluster.apiserver.api import (
)
from dbgpt.model.cluster.tests.conftest import _new_cluster
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
from dbgpt.util.fastapi import create_app
from dbgpt.util.openai_utils import chat_completion, chat_completion_stream
app = FastAPI()
app = create_app()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -69,7 +70,7 @@ async def client(request, system_app: SystemApp):
async def test_get_all_models(client: AsyncClient):
res = await client.get("/api/v1/models")
res.status_code == 200
model_lists = ModelList.parse_obj(res.json())
model_lists = ModelList.model_validate(res.json())
print(f"model list json: {res.json()}")
assert model_lists.object == "list"
assert len(model_lists.data) == 2

View File

@@ -2,7 +2,7 @@ import logging
from abc import ABC, abstractmethod
from typing import List
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
@@ -10,6 +10,7 @@ from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from dbgpt.model.parameter import ModelControllerParameters
from dbgpt.util.api_utils import _api_remote as api_remote
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
from dbgpt.util.fastapi import create_app
from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.util.utils import setup_http_service_logging, setup_logging
@@ -152,7 +153,7 @@ def initialize_controller(
import uvicorn
setup_http_service_logging()
app = FastAPI()
app = create_app()
app.include_router(router, prefix="/api", tags=["Model"])
uvicorn.run(app, host=host, port=port, log_level="info")

View File

@@ -11,7 +11,7 @@ from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Iterator
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from dbgpt.component import SystemApp
@@ -28,6 +28,7 @@ from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
from dbgpt.model.utils.llm_utils import list_supported_models
from dbgpt.util.fastapi import create_app, register_event_handler
from dbgpt.util.parameter_utils import (
EnvArgumentParser,
ParameterDescription,
@@ -829,7 +830,7 @@ def _setup_fastapi(
worker_params: ModelWorkerParameters, app=None, ignore_exception: bool = False
):
if not app:
app = FastAPI()
app = create_app()
setup_http_service_logging()
if worker_params.standalone:
@@ -850,7 +851,6 @@ def _setup_fastapi(
initialize_controller(app=app)
app.include_router(controller_router, prefix="/api")
@app.on_event("startup")
async def startup_event():
async def start_worker_manager():
try:
@@ -865,10 +865,11 @@ def _setup_fastapi(
# the fastapi app (registered to the controller)
asyncio.create_task(start_worker_manager())
@app.on_event("shutdown")
async def startup_event():
async def shutdown_event():
await worker_manager.stop(ignore_exception=ignore_exception)
register_event_handler(app, "startup", startup_event)
register_event_handler(app, "shutdown", shutdown_event)
return app