Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-09 12:59:43 +00:00)
feat(core): Upgrade pydantic to 2.x (#1428)
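The heart of the upgrade: pydantic 2.x renames the v1 serialization API (BaseModel.dict() becomes model_dump(), BaseModel.json() becomes model_dump_json()), so the call sites below are routed through compatibility helpers imported from dbgpt._private.pydantic. Those helpers are not part of this diff; a minimal sketch of what they plausibly look like, assuming pydantic 2.x is installed:

import json
from typing import Any, Dict

from pydantic import BaseModel


def model_to_dict(model: BaseModel, **kwargs) -> Dict[str, Any]:
    # pydantic 2.x replacement for the v1 BaseModel.dict() call
    return model.model_dump(**kwargs)


def model_to_json(model: BaseModel, ensure_ascii: bool = True, **kwargs) -> str:
    # model_dump_json() no longer forwards json.dumps kwargs such as
    # ensure_ascii, so serialize the dumped dict manually
    return json.dumps(model.model_dump(**kwargs), ensure_ascii=ensure_ascii)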
@@ -10,6 +10,7 @@ from fastapi import APIRouter, Body, Depends, File, UploadFile
 from fastapi.responses import StreamingResponse
 
 from dbgpt._private.config import Config
+from dbgpt._private.pydantic import model_to_dict, model_to_json
 from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
 from dbgpt.app.knowledge.service import KnowledgeService
 from dbgpt.app.openapi.api_view_model import (
@@ -147,7 +148,7 @@ def get_executor() -> Executor:
 ).create()
 
 
-@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
+@router.get("/v1/chat/db/list", response_model=Result[List[DBConfig]])
 async def db_connect_list():
     return Result.succ(CFG.local_db_manager.get_db_list())
 
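Several response models change from Result[X] to Result[List[X]] in this commit, presumably because pydantic 2.x (which FastAPI uses to validate response_model) is strict about a list payload declared as a single model. A sketch of the difference, with a hypothetical DBConfig stand-in:

from typing import Generic, List, Optional, TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar("T")


class Result(BaseModel, Generic[T]):
    success: bool
    data: Optional[T] = None


class DBConfig(BaseModel):  # hypothetical stand-in for the real model
    db_name: str = "test"


# Correct: data validates as a list of DBConfig
Result[List[DBConfig]](success=True, data=[DBConfig()])

try:
    # Wrong annotation: a list is not a single DBConfig
    Result[DBConfig](success=True, data=[DBConfig()])
except ValidationError as e:
    print(e)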
@@ -189,7 +190,7 @@ async def db_summary(db_name: str, db_type: str):
     return Result.succ(True)
 
 
-@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
+@router.get("/v1/chat/db/support/type", response_model=Result[List[DbTypeInfo]])
 async def db_support_types():
     support_types = CFG.local_db_manager.get_all_completed_types()
     db_type_infos = []
@@ -223,7 +224,7 @@ async def dialogue_scenes():
     return Result.succ(scene_vos)
 
 
-@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
+@router.post("/v1/chat/mode/params/list", response_model=Result[dict | list])
 async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
    if ChatScene.ChatWithDbQA.value() == chat_mode:
        return Result.succ(get_db_list())
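Note the dict | list unions: PEP 604 syntax is evaluated when the decorator runs, so these annotations assume Python 3.10+; Union[dict, list] is the equivalent on older interpreters, and pydantic 2.x validates either spelling:

from typing import Union

from pydantic import TypeAdapter

adapter = TypeAdapter(Union[dict, list])  # same as TypeAdapter(dict | list)
print(adapter.validate_python({"a": 1}))  # {'a': 1}
print(adapter.validate_python([1, 2]))    # [1, 2]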
@@ -378,7 +379,9 @@ async def chat_completions(
         )
     else:
         with root_tracer.start_span(
-            "get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
+            "get_chat_instance",
+            span_type=SpanType.CHAT,
+            metadata=model_to_dict(dialogue),
         ):
             chat: BaseChat = await get_chat_instance(dialogue)
 
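dialogue.dict() would still run under pydantic 2.x, but only as a deprecated alias that warns on every call; switching to model_to_dict keeps the tracing metadata warning-free. The 2.x behavior, with a hypothetical minimal model:

import warnings

from pydantic import BaseModel


class Dialogue(BaseModel):  # hypothetical stand-in for ConversationVo
    chat_mode: str = "chat_normal"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Dialogue().dict()  # v1 alias, deprecated in pydantic 2.x
print(caught[0].category.__name__)  # PydanticDeprecatedSince20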
@@ -458,7 +461,10 @@ async def stream_generator(chat, incremental: bool, model_name: str):
             chunk = ChatCompletionStreamResponse(
                 id=chat.chat_session_id, choices=[choice_data], model=model_name
             )
-            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            json_chunk = model_to_json(
+                chunk, exclude_unset=True, ensure_ascii=False
+            )
+            yield f"data: {json_chunk}\n\n"
         else:
             # TODO generate an openai-compatible streaming responses
             msg = msg.replace("\n", "\\n")
@@ -43,7 +43,7 @@ def get_edit_service() -> EditorService:
     return EditorService.get_instance(CFG.SYSTEM_APP)
 
 
-@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
+@router.get("/v1/editor/db/tables", response_model=Result[DataNode])
 async def get_editor_tables(
     db_name: str, page_index: int, page_size: int, search_str: str = ""
 ):
@@ -70,15 +70,15 @@ async def get_editor_tables(
     return Result.succ(db_node)
 
 
-@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
+@router.get("/v1/editor/sql/rounds", response_model=Result[List[ChatDbRounds]])
 async def get_editor_sql_rounds(
     con_uid: str, editor_service: EditorService = Depends(get_edit_service)
 ):
-    logger.info("get_editor_sql_rounds:{con_uid}")
+    logger.info(f"get_editor_sql_rounds:{con_uid}")
     return Result.succ(editor_service.get_editor_sql_rounds(con_uid))
 
 
-@router.get("/v1/editor/sql", response_model=Result[dict])
+@router.get("/v1/editor/sql", response_model=Result[dict | list])
 async def get_editor_sql(
     con_uid: str, round: int, editor_service: EditorService = Depends(get_edit_service)
 ):
@@ -107,7 +107,7 @@ async def editor_sql_run(run_param: dict = Body()):
     end_time = time.time() * 1000
     sql_run_data: SqlRunData = SqlRunData(
         result_info="",
-        run_cost=(end_time - start_time) / 1000,
+        run_cost=int((end_time - start_time) / 1000),
         colunms=colunms,
         values=sql_result,
     )
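The explicit int() cast pairs with the run_cost field changing from str to int further down: pydantic 2.x coerces a float into an int field only when it is lossless, and rejects values with a fractional part rather than truncating them. For example:

from pydantic import BaseModel, ValidationError


class SqlRunData(BaseModel):  # reduced to the relevant field
    run_cost: int


print(SqlRunData(run_cost=3.0).run_cost)  # 3: lossless coercion is allowed

try:
    SqlRunData(run_cost=3.7)
except ValidationError as e:
    print(e)  # "...got a number with a fractional part"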
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import json
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from dbgpt._private.config import Config
 from dbgpt.app.openapi.api_view_model import Result
@@ -70,7 +70,7 @@ class EditorService(BaseComponent):
 
     def get_editor_sql_by_round(
         self, conv_uid: str, round_index: int
-    ) -> Optional[Dict]:
+    ) -> Optional[Union[List, Dict]]:
         storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
         messages_by_round = _split_messages_by_round(storage_conv.messages)
         for one_round_message in messages_by_round:
@@ -184,7 +184,7 @@ class EditorService(BaseComponent):
         return Result.failed(msg="Can't Find Chart Detail Info!")
 
 
-def _parse_pure_dict(res_str: str) -> Dict:
+def _parse_pure_dict(res_str: str) -> Union[Dict, List]:
     output_parser = BaseOutputParser()
     context = output_parser.parse_prompt_response(res_str)
     return json.loads(context)
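The widened Union[Dict, List] return type reflects what json.loads actually produces: the parsed LLM output may be a JSON object or a JSON array at the top level:

import json

print(type(json.loads('{"sql": "SELECT 1"}')))    # <class 'dict'>
print(type(json.loads('[{"sql": "SELECT 1"}]')))  # <class 'list'>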
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List, Optional
 
 from dbgpt._private.pydantic import BaseModel
 from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem
@@ -9,15 +9,15 @@ class DataNode(BaseModel):
     key: str
 
     type: str = ""
-    default_value: str = None
+    default_value: Optional[Any] = None
     can_null: str = "YES"
-    comment: str = None
+    comment: Optional[str] = None
     children: List = []
 
 
 class SqlRunData(BaseModel):
     result_info: str
-    run_cost: str
+    run_cost: int
     colunms: List[str]
     values: List
 
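The Optional annotations are needed because pydantic 2.x dropped v1's implicit rule that a field with a None default is automatically Optional; under 2.x, comment: str = None keeps the default but rejects an explicitly passed None. A minimal illustration:

from typing import Optional

from pydantic import BaseModel, ValidationError


class V1Style(BaseModel):
    comment: str = None  # v1 silently treated this as Optional[str]


class V2Style(BaseModel):
    comment: Optional[str] = None  # explicit, as in this commit


V2Style(comment=None)  # fine

try:
    V1Style(comment=None)  # pydantic 2.x: a str field rejects an explicit None
except ValidationError as e:
    print(e)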
@@ -8,6 +8,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from starlette.responses import JSONResponse, StreamingResponse
 
+from dbgpt._private.pydantic import model_to_dict, model_to_json
 from dbgpt.app.openapi.api_v1.api_v1 import (
     CHAT_FACTORY,
     __new_conversation,
@@ -130,7 +131,9 @@ async def chat_completions(
         or request.chat_mode == ChatMode.CHAT_DATA.value
     ):
         with root_tracer.start_span(
-            "get_chat_instance", span_type=SpanType.CHAT, metadata=request.dict()
+            "get_chat_instance",
+            span_type=SpanType.CHAT,
+            metadata=model_to_dict(request),
         ):
             chat: BaseChat = await get_chat_instance(request)
 
@@ -243,21 +246,22 @@ async def chat_app_stream_wrapper(request: ChatCompletionRequestBody = None):
                 model=request.model,
                 created=int(time.time()),
             )
-            content = (
-                f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
-            )
+            json_content = model_to_json(
+                chunk, exclude_unset=True, ensure_ascii=False
+            )
+            content = f"data: {json_content}\n\n"
             yield content
     yield "data: [DONE]\n\n"
 
 
 async def chat_flow_wrapper(request: ChatCompletionRequestBody):
     flow_service = get_chat_flow()
-    flow_req = CommonLLMHttpRequestBody(**request.dict())
+    flow_req = CommonLLMHttpRequestBody(**model_to_dict(request))
     flow_uid = request.chat_param
     output = await flow_service.safe_chat_flow(flow_uid, flow_req)
     if not output.success:
         return JSONResponse(
-            ErrorResponse(message=output.text, code=output.error_code).dict(),
+            model_to_dict(ErrorResponse(message=output.text, code=output.error_code)),
             status_code=400,
         )
     else:
@@ -282,7 +286,7 @@ async def chat_flow_stream_wrapper(
         request (OpenAPIChatCompletionRequest): request
     """
     flow_service = get_chat_flow()
-    flow_req = CommonLLMHttpRequestBody(**request.dict())
+    flow_req = CommonLLMHttpRequestBody(**model_to_dict(request))
     flow_uid = request.chat_param
 
     async for output in flow_service.chat_stream_openai(flow_uid, flow_req):
@@ -1,17 +1,15 @@
 import time
 import uuid
-from typing import Any, Generic, List, Literal, Optional, TypeVar
+from typing import Any, Dict, Generic, Optional, TypeVar
 
-from dbgpt._private.pydantic import BaseModel, Field
+from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
 
 T = TypeVar("T")
 
 
-class Result(Generic[T], BaseModel):
+class Result(BaseModel, Generic[T]):
     success: bool
-    err_code: str = None
-    err_msg: str = None
-    data: T = None
+    err_code: Optional[str] = None
+    err_msg: Optional[str] = None
+    data: Optional[T] = None
 
     @classmethod
     def succ(cls, data: T):
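The swapped base order matters in pydantic 2.x: with BaseModel listed first, Result[SomeModel] is parametrized by pydantic's __class_getitem__ and produces a real model class, whereas with Generic[T] first the subscript presumably resolves to a plain typing alias. A sketch of the 2.x pattern:

from typing import Generic, Optional, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class Result(BaseModel, Generic[T]):  # BaseModel first, per pydantic 2.x docs
    success: bool
    err_code: Optional[str] = None
    data: Optional[T] = None


# Result[int] is a real parametrized pydantic model, so data is validated:
print(Result[int](success=True, data="42").data)  # coerced to int 42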
@@ -21,6 +19,9 @@ class Result(Generic[T], BaseModel):
     def failed(cls, code: str = "E000X", msg=None):
         return Result(success=False, err_code=code, err_msg=msg, data=None)
 
+    def to_dict(self) -> Dict[str, Any]:
+        return model_to_dict(self)
+
 
 class ChatSceneVo(BaseModel):
     chat_scene: str = Field(..., description="chat_scene")
@@ -31,6 +32,8 @@ class ChatSceneVo(BaseModel):
 
 
 class ConversationVo(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
     """
     dialogue_uid
     """
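model_config = ConfigDict(protected_namespaces=()) is needed because ConversationVo (and MessageVo below) define a model_name field, and pydantic 2.x reserves the model_ prefix for its own attributes, warning about such fields unless the protection is disabled:

from pydantic import BaseModel, ConfigDict


class Warns(BaseModel):
    model_name: str = ""  # 2.x warns: conflicts with protected namespace "model_"


class Quiet(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    model_name: str = ""  # no warning: protection disabled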
@@ -43,7 +46,7 @@ class ConversationVo(BaseModel):
     """
     user
     """
-    user_name: str = None
+    user_name: Optional[str] = Field(None, description="user name")
     """
     the scene of chat
     """
@@ -52,21 +55,23 @@ class ConversationVo(BaseModel):
     """
     chat scene select param
     """
-    select_param: str = None
+    select_param: Optional[str] = Field(None, description="chat scene select param")
     """
     llm model name
     """
-    model_name: str = None
+    model_name: Optional[str] = Field(None, description="llm model name")
 
     """Used to control whether the content is returned incrementally or in full each time.
     If this parameter is not provided, the default is full return.
     """
     incremental: bool = False
 
-    sys_code: Optional[str] = None
+    sys_code: Optional[str] = Field(None, description="System code")
 
 
 class MessageVo(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
     """
     role that sends out the current message
     """
@@ -83,7 +88,9 @@ class MessageVo(BaseModel):
     """
     time the current message was sent
     """
-    time_stamp: Any = None
+    time_stamp: Optional[Any] = Field(
+        None, description="time the current message was sent"
+    )
 
     """
     model_name
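Moving the bare docstring comments into Field(None, description=...) does double duty under pydantic 2.x: the field stays optional with an explicit Optional annotation, and the description lands in the generated JSON schema, which FastAPI surfaces in the OpenAPI docs:

from typing import Optional

from pydantic import BaseModel, Field


class ConversationVo(BaseModel):  # reduced to one field for illustration
    user_name: Optional[str] = Field(None, description="user name")


schema = ConversationVo.model_json_schema()
print(schema["properties"]["user_name"]["description"])  # user name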
@@ -11,4 +11,4 @@ async def validation_exception_handler(request: Request, exc: RequestValidationError):
         loc = ".".join(list(map(str, error.get("loc"))))
         message += loc + ":" + error.get("msg") + ";"
     res = Result.failed(code="E0001", msg=message)
-    return JSONResponse(status_code=400, content=res.dict())
+    return JSONResponse(status_code=400, content=res.to_dict())
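Finally, the exception handler stops calling the deprecated res.dict() and goes through the new Result.to_dict() helper, which returns a plain dict that JSONResponse can serialize directly. Roughly equivalent, under the helper sketched at the top:

from typing import Optional

from pydantic import BaseModel


class Res(BaseModel):  # hypothetical stand-in for Result
    success: bool
    err_code: Optional[str] = None
    err_msg: Optional[str] = None


res = Res(success=False, err_code="E0001", err_msg="field required")
print(res.model_dump())  # the 2.x call behind to_dict()/model_to_dict()
# {'success': False, 'err_code': 'E0001', 'err_msg': 'field required'}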