feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -9,14 +9,18 @@ import logging
from typing import Any, Dict, Generator, List, Optional
import shortuuid
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from fastchat.constants import ErrorCode
from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
from fastchat.protocol.openai_api_protocol import (
from dbgpt._private.pydantic import BaseModel, model_to_dict, model_to_json
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.core.schema.api import (
APIChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
@@ -25,20 +29,18 @@ from fastchat.protocol.openai_api_protocol import (
DeltaMessage,
EmbeddingsRequest,
EmbeddingsResponse,
ErrorCode,
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
from dbgpt._private.pydantic import BaseModel
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
from dbgpt.util.fastapi import create_app
from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.util.utils import setup_logging
@@ -88,7 +90,7 @@ def create_error_response(code: int, message: str) -> JSONResponse:
We can't use fastchat.serve.openai_api_server because it has too many dependencies.
"""
return JSONResponse(
ErrorResponse(message=message, code=code).dict(), status_code=400
model_to_dict(ErrorResponse(message=message, code=code)), status_code=400
)
@@ -266,7 +268,8 @@ class APIServer(BaseComponent):
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
previous_text = ""
async for model_output in worker_manager.generate_stream(params):
@@ -297,10 +300,15 @@ class APIServer(BaseComponent):
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
yield f"data: {json_data}\n\n"
# There is no "content" field in the last delta message, so use exclude_none to exclude the "content" field.
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
json_data = model_to_json(
finish_chunk, exclude_unset=True, ensure_ascii=False
)
yield f"data: {json_data}\n\n"
yield "data: [DONE]\n\n"
async def chat_completion_generate(
@@ -335,8 +343,8 @@ class APIServer(BaseComponent):
)
)
if model_output.usage:
task_usage = UsageInfo.parse_obj(model_output.usage)
for usage_key, usage_value in task_usage.dict().items():
task_usage = UsageInfo.model_validate(model_output.usage)
for usage_key, usage_value in model_to_dict(task_usage).items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
@@ -442,8 +450,9 @@ async def create_embeddings(
}
for i, emb in enumerate(embeddings)
]
return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
exclude_none=True
return model_to_dict(
EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()),
exclude_none=True,
)
@@ -492,7 +501,7 @@ def initialize_apiserver(
embedded_mod = True
if not app:
embedded_mod = False
app = FastAPI()
app = create_app()
if not system_app:
system_app = SystemApp(app)