feat(core): Upgrade pydantic to 2.x (#1428)

Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions
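
The diffs below repeatedly swap direct pydantic v1 calls such as dialogue.dict() and chunk.json(...) for the model_to_dict / model_to_json helpers imported from dbgpt._private.pydantic. That module itself is not shown in this commit, so the following is only a minimal sketch of what such a compatibility shim typically looks like, assuming it dispatches on the installed pydantic major version; everything except the two helper names is illustrative.

# Hypothetical sketch of the compatibility helpers used throughout this commit.
# dbgpt._private.pydantic is not part of this diff; the dispatch logic below is
# an assumption about how a v1/v2 shim is usually written.
import json
from typing import Any, Dict

from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel

PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")


def model_to_dict(model: BaseModel, **kwargs: Any) -> Dict[str, Any]:
    """Dump a model to a dict on either pydantic major version."""
    if PYDANTIC_V2:
        return model.model_dump(**kwargs)
    return model.dict(**kwargs)


def model_to_json(model: BaseModel, ensure_ascii: bool = False, **kwargs: Any) -> str:
    """Dump a model to JSON; pydantic 2 dropped the ensure_ascii pass-through."""
    if PYDANTIC_V2:
        if not ensure_ascii:
            # model_dump_json keeps non-ASCII characters as-is by default.
            return model.model_dump_json(**kwargs)
        # No ensure_ascii flag in v2, so re-serialize through json.dumps.
        return json.dumps(json.loads(model.model_dump_json(**kwargs)), ensure_ascii=True)
    # pydantic 1 forwards extra keyword arguments to json.dumps.
    return model.json(ensure_ascii=ensure_ascii, **kwargs)
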

View File

@@ -10,6 +10,7 @@ from fastapi import APIRouter, Body, Depends, File, UploadFile
from fastapi.responses import StreamingResponse
from dbgpt._private.config import Config
+from dbgpt._private.pydantic import model_to_dict, model_to_json
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.openapi.api_view_model import (
@@ -147,7 +148,7 @@ def get_executor() -> Executor:
).create()
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
@router.get("/v1/chat/db/list", response_model=Result[List[DBConfig]])
async def db_connect_list():
return Result.succ(CFG.local_db_manager.get_db_list())
@@ -189,7 +190,7 @@ async def db_summary(db_name: str, db_type: str):
return Result.succ(True)
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
@router.get("/v1/chat/db/support/type", response_model=Result[List[DbTypeInfo]])
async def db_support_types():
support_types = CFG.local_db_manager.get_all_completed_types()
db_type_infos = []
@@ -223,7 +224,7 @@ async def dialogue_scenes():
return Result.succ(scene_vos)
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
@router.post("/v1/chat/mode/params/list", response_model=Result[dict | list])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
if ChatScene.ChatWithDbQA.value() == chat_mode:
return Result.succ(get_db_list())
@@ -378,7 +379,9 @@ async def chat_completions(
)
else:
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
"get_chat_instance",
span_type=SpanType.CHAT,
metadata=model_to_dict(dialogue),
):
chat: BaseChat = await get_chat_instance(dialogue)
@@ -458,7 +461,10 @@ async def stream_generator(chat, incremental: bool, model_name: str):
chunk = ChatCompletionStreamResponse(
id=chat.chat_session_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
json_chunk = model_to_json(
chunk, exclude_unset=True, ensure_ascii=False
)
yield f"data: {json_chunk}\n\n"
else:
# TODO generate an openai-compatible streaming responses
msg = msg.replace("\n", "\\n")
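
Under pydantic 2 the v1-style .json() method is deprecated and no longer forwards keyword arguments such as ensure_ascii to json.dumps, which is why the stream generator now routes through model_to_json. A small, self-contained illustration of the new call pattern; ChunkDemo is a made-up stand-in for ChatCompletionStreamResponse.

# Illustration of the new serialization path in stream_generator.
# ChunkDemo is hypothetical; the real ChatCompletionStreamResponse model is
# defined elsewhere in dbgpt and carries more fields.
from pydantic import BaseModel

from dbgpt._private.pydantic import model_to_json


class ChunkDemo(BaseModel):
    id: str
    model: str = ""
    content: str = ""


chunk = ChunkDemo(id="conv-1", content="你好")

# Pydantic 1 allowed chunk.json(exclude_unset=True, ensure_ascii=False) because
# extra kwargs were passed straight to json.dumps; pydantic 2's model_dump_json
# has no ensure_ascii parameter, so the shim absorbs the difference.
json_chunk = model_to_json(chunk, exclude_unset=True, ensure_ascii=False)
print(f"data: {json_chunk}\n\n")  # the SSE framing used by the endpoint
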

View File

@@ -43,7 +43,7 @@ def get_edit_service() -> EditorService:
return EditorService.get_instance(CFG.SYSTEM_APP)
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
@router.get("/v1/editor/db/tables", response_model=Result[DataNode])
async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
):
@@ -70,15 +70,15 @@ async def get_editor_tables(
return Result.succ(db_node)
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
@router.get("/v1/editor/sql/rounds", response_model=Result[List[ChatDbRounds]])
async def get_editor_sql_rounds(
con_uid: str, editor_service: EditorService = Depends(get_edit_service)
):
logger.info("get_editor_sql_rounds:{con_uid}")
logger.info(f"get_editor_sql_rounds:{ con_uid}")
return Result.succ(editor_service.get_editor_sql_rounds(con_uid))
@router.get("/v1/editor/sql", response_model=Result[dict])
@router.get("/v1/editor/sql", response_model=Result[dict | list])
async def get_editor_sql(
con_uid: str, round: int, editor_service: EditorService = Depends(get_edit_service)
):
@@ -107,7 +107,7 @@ async def editor_sql_run(run_param: dict = Body()):
end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(
result_info="",
-run_cost=(end_time - start_time) / 1000,
+run_cost=int((end_time - start_time) / 1000),
colunms=colunms,
values=sql_result,
)
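
run_cost changes from str to int and the endpoint now truncates the elapsed time explicitly. Pydantic 1 happily coerced a float into a str field; pydantic 2 is stricter and, even for int fields, its default (lax) mode rejects floats that have a fractional part. A small demonstration with a trimmed-down, hypothetical stand-in for SqlRunData.

# SqlRunDataDemo is a trimmed, hypothetical stand-in for the real SqlRunData model.
import time

from pydantic import BaseModel, ValidationError


class SqlRunDataDemo(BaseModel):
    result_info: str = ""
    run_cost: int = 0


start_time = time.time() * 1000
end_time = start_time + 1234.5  # pretend the query took ~1.23 seconds

try:
    # Pydantic 2 refuses to silently truncate 1.2345 to an int.
    SqlRunDataDemo(run_cost=(end_time - start_time) / 1000)
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # e.g. int_from_float

# Explicit truncation, matching the updated editor_sql_run endpoint.
data = SqlRunDataDemo(run_cost=int((end_time - start_time) / 1000))
print(data.run_cost)  # 1
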

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
from dbgpt._private.config import Config
from dbgpt.app.openapi.api_view_model import Result
@@ -70,7 +70,7 @@ class EditorService(BaseComponent):
def get_editor_sql_by_round(
self, conv_uid: str, round_index: int
-) -> Optional[Dict]:
+) -> Optional[Union[List, Dict]]:
storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
messages_by_round = _split_messages_by_round(storage_conv.messages)
for one_round_message in messages_by_round:
@@ -184,7 +184,7 @@ class EditorService(BaseComponent):
return Result.failed(msg="Can't Find Chart Detail Info!")
-def _parse_pure_dict(res_str: str) -> Dict:
+def _parse_pure_dict(res_str: str) -> Union[Dict, List]:
output_parser = BaseOutputParser()
context = output_parser.parse_prompt_response(res_str)
return json.loads(context)
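
The widened annotations reflect that the stored prompt responses are not always JSON objects: json.loads can just as well return a list, which then flows back through the Result[dict | list] response models on the API side. A minimal illustration, with BaseOutputParser simplified away to plain JSON parsing.

# Simplified illustration; the real _parse_pure_dict first runs
# BaseOutputParser.parse_prompt_response to strip code fences and similar noise.
import json
from typing import Dict, List, Union


def _parse_pure_dict_demo(res_str: str) -> Union[Dict, List]:
    return json.loads(res_str)


print(_parse_pure_dict_demo('{"sql": "SELECT 1"}'))    # a dict
print(_parse_pure_dict_demo('[{"sql": "SELECT 1"}]'))  # a list
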

View File

@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List, Optional
from dbgpt._private.pydantic import BaseModel
from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem
@@ -9,15 +9,15 @@ class DataNode(BaseModel):
key: str
type: str = ""
-default_value: str = None
+default_value: Optional[Any] = None
can_null: str = "YES"
-comment: str = None
+comment: Optional[str] = None
children: List = []
class SqlRunData(BaseModel):
result_info: str
-run_cost: str
+run_cost: int
colunms: List[str]
values: List
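
The DataNode defaults are the classic pydantic 1 to 2 migration case: v1 silently treated a field declared as "str = None" as Optional[str], while v2 keeps the annotation exactly as written, so a None default (or an explicit None value) needs an explicit Optional[...]. A small sketch with a trimmed, hypothetical stand-in for DataNode.

# DataNodeDemo is a trimmed, hypothetical stand-in for the real DataNode model.
from typing import Any, List, Optional

from pydantic import BaseModel, ValidationError


class DataNodeDemo(BaseModel):
    key: str
    type: str = ""
    default_value: Optional[Any] = None  # was "str = None" before the upgrade
    can_null: str = "YES"
    comment: Optional[str] = None
    children: List = []


# Explicit None is accepted only because the annotations are now Optional.
node = DataNodeDemo(key="id", comment=None)
print(node.default_value, node.comment)

try:
    # A non-Optional str field still rejects None under pydantic 2.
    DataNodeDemo(key="id", can_null=None)
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # e.g. string_type
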