Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-09 21:08:59 +00:00)
feat(ChatKnowledge): Support Financial Report Analysis (#1702)
Co-authored-by: hzh97 <2976151305@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: licunxing <864255598@qq.com>
@@ -3,7 +3,7 @@ import logging
 import os
 import uuid
 from concurrent.futures import Executor
-from typing import List, Optional
+from typing import List, Optional, cast

 import aiofiles
 from fastapi import APIRouter, Body, Depends, File, UploadFile
@@ -21,8 +21,11 @@ from dbgpt.app.openapi.api_view_model import (
 )
 from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene
 from dbgpt.component import ComponentType
+from dbgpt.configs import TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE
 from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
-from dbgpt.core.awel import CommonLLMHttpRequestBody
+from dbgpt.core.awel import BaseOperator, CommonLLMHttpRequestBody
+from dbgpt.core.awel.dag.dag_manager import DAGManager
+from dbgpt.core.awel.util.chat_util import safe_chat_stream_with_dag_task
 from dbgpt.core.schema.api import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -127,6 +130,11 @@ def get_worker_manager() -> WorkerManager:
     return worker_manager


+def get_dag_manager() -> DAGManager:
+    """Get the global default DAGManager"""
+    return DAGManager.get_instance(CFG.SYSTEM_APP)
+
+
 def get_chat_flow() -> FlowService:
     """Get Chat Flow Service."""
     return FlowService.get_instance(CFG.SYSTEM_APP)
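Note: get_dag_manager() mirrors get_chat_flow() directly below it; both resolve a singleton component from CFG.SYSTEM_APP. In this commit it is only called imperatively inside chat_with_domain_flow (added at the end of this diff), but it could also be wired into a route with FastAPI's dependency system, the same way flow_service: FlowService = Depends(get_chat_flow) already is. A minimal sketch of that pattern, written as if it lived in this same module (so router and get_dag_manager are in scope); the route path, the domain_type query parameter, and the handler itself are assumptions for illustration, not part of this change:

# Hypothetical usage sketch (not part of this diff): expose the tagged-DAG
# lookup through a route, injecting the DAGManager with FastAPI's Depends.
from typing import List

from fastapi import Depends

from dbgpt.configs import TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE
from dbgpt.core.awel.dag.dag_manager import DAGManager


@router.get("/v1/chat/domain/dags")  # route path is an assumption
async def list_domain_dags(
    domain_type: str,
    dag_manager: DAGManager = Depends(get_dag_manager),
) -> List[str]:
    """Return ids of AWEL DAGs tagged for the given knowledge domain type."""
    dags = dag_manager.get_dags_by_tag(TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE, domain_type)
    return [dag.dag_id for dag in dags]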
@@ -252,7 +260,7 @@ async def params_load(
     sys_code: Optional[str] = None,
     doc_file: UploadFile = File(...),
 ):
-    print(f"params_load: {conv_uid},{chat_mode},{model_name}")
+    logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}")
     try:
         if doc_file:
             # Save the uploaded file
@@ -335,7 +343,7 @@ async def chat_completions(
     dialogue: ConversationVo = Body(),
     flow_service: FlowService = Depends(get_chat_flow),
 ):
-    print(
+    logger.info(
         f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
     )
     headers = {
@@ -344,6 +352,7 @@ async def chat_completions(
         "Connection": "keep-alive",
         "Transfer-Encoding": "chunked",
     }
+    domain_type = _parse_domain_type(dialogue)
     if dialogue.chat_mode == ChatScene.ChatAgent.value():
         return StreamingResponse(
             multi_agents.app_agent_chat(
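The resolved domain_type comes from the selected knowledge space's domain_type field (see _parse_domain_type at the end of this diff); only when it is set does chat_completions leave the default knowledge path. A hypothetical illustration of a request that would take the domain branch, written as if inside this module; every field value here ("fin_2023_reports", "glm-4", the question text) is an example, not something this diff prescribes:

# Illustrative only: a ChatKnowledge request whose selected space carries a
# domain_type is rerouted to chat_with_domain_flow instead of the normal
# ChatKnowledge scene.
dialogue = ConversationVo(
    conv_uid="demo-conv-1",
    chat_mode=ChatScene.ChatKnowledge.value(),   # "chat_knowledge"
    select_param="fin_2023_reports",             # knowledge space name (example)
    model_name="glm-4",                          # example model
    user_input="What was the operating margin last year?",
)
domain_type = _parse_domain_type(dialogue)       # e.g. "FinancialReport" if the space sets it
if domain_type is not None:
    # chat_completions streams the answer through the tagged AWEL flow.
    ...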
@@ -378,12 +387,20 @@ async def chat_completions(
             headers=headers,
             media_type="text/event-stream",
         )
+    elif domain_type is not None:
+        return StreamingResponse(
+            chat_with_domain_flow(dialogue, domain_type),
+            headers=headers,
+            media_type="text/event-stream",
+        )

     else:
         with root_tracer.start_span(
             "get_chat_instance",
             span_type=SpanType.CHAT,
             metadata=model_to_dict(dialogue),
         ):

             chat: BaseChat = await get_chat_instance(dialogue)

         if not chat.prompt_template.stream_out:
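This branch hands the dialogue to an AWEL flow instead of a BaseChat scene, so a flow has to register itself under TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE for the get_dags_by_tag lookup in chat_with_domain_flow (below) to find it. A hypothetical sketch of such a registration follows, assuming the AWEL DAG constructor accepts a tags mapping that DAGManager indexes; the operator and the "FinancialReport" tag value are placeholders, not taken from this diff:

# Hypothetical sketch of a domain-tagged flow (assumption: DAG(...) accepts a
# tags mapping that DAGManager.get_dags_by_tag matches against).
from dbgpt.configs import TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE
from dbgpt.core.awel import DAG, MapOperator


class FinReportEndOperator(MapOperator[str, str]):
    """Placeholder end task; a real flow would run the financial-report chat."""

    async def map(self, user_input: str) -> str:
        return f"[fin-report] {user_input}"


with DAG(
    "example_financial_report_chat",
    tags={TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE: "FinancialReport"},
) as _dag:
    # chat_with_domain_flow streams from dags[0].leaf_nodes[0], so the last
    # operator in the flow is the one whose output reaches the client.
    FinReportEndOperator()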
@@ -484,3 +501,61 @@ def message2Vo(message: dict, order, model_name) -> MessageVo:
         order=order,
         model_name=model_name,
     )
+
+
+def _parse_domain_type(dialogue: ConversationVo) -> Optional[str]:
+    if dialogue.chat_mode == ChatScene.ChatKnowledge.value():
+        # Supported in the knowledge chat
+        space_name = dialogue.select_param
+        spaces = knowledge_service.get_knowledge_space(
+            KnowledgeSpaceRequest(name=space_name)
+        )
+        if len(spaces) == 0:
+            return Result.failed(
+                code="E000X", msg=f"Knowledge space {space_name} not found"
+            )
+        if spaces[0].domain_type:
+            return spaces[0].domain_type
+    else:
+        return None
+
+
+async def chat_with_domain_flow(dialogue: ConversationVo, domain_type: str):
+    """Chat with domain flow"""
+    dag_manager = get_dag_manager()
+    dags = dag_manager.get_dags_by_tag(TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE, domain_type)
+    if not dags or not dags[0].leaf_nodes:
+        raise ValueError(f"Cant find the DAG for domain type {domain_type}")
+
+    end_task = cast(BaseOperator, dags[0].leaf_nodes[0])
+
+    space = dialogue.select_param
+    connector_manager = CFG.local_db_manager
+    # TODO: Some flow maybe not connector
+    db_list = [item["db_name"] for item in connector_manager.get_db_list()]
+    db_names = [item for item in db_list if space in item]
+    if len(db_names) == 0:
+        raise ValueError(f"fin repost dbname {space}_fin_report not found.")
+    flow_ctx = {"space": space, "db_name": db_names[0]}
+    request = CommonLLMHttpRequestBody(
+        model=dialogue.model_name,
+        messages=dialogue.user_input,
+        stream=True,
+        extra=flow_ctx,
+        conv_uid=dialogue.conv_uid,
+        span_id=root_tracer.get_current_span_id(),
+        chat_mode=dialogue.chat_mode,
+        chat_param=dialogue.select_param,
+        user_name=dialogue.user_name,
+        sys_code=dialogue.sys_code,
+        incremental=dialogue.incremental,
+    )
+    async for output in safe_chat_stream_with_dag_task(end_task, request, False):
+        text = output.text
+        if text:
+            text = text.replace("\n", "\\n")
+        if output.error_code != 0:
+            yield f"data:[SERVER_ERROR]{text}\n\n"
+            break
+        else:
+            yield f"data:{text}\n\n"
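chat_with_domain_flow packs the dialogue into a CommonLLMHttpRequestBody, passes the knowledge space and a registered database whose name contains the space name (expected to be the financial-report db, e.g. <space>_fin_report) through extra, and yields server-sent-event frames: each chunk is a data: line with real newlines escaped as literal \n sequences, and failures carry a [SERVER_ERROR] prefix. A minimal client sketch that consumes this stream; the endpoint path, port, and payload values are assumptions based on this file's ConversationVo fields, not something this diff defines:

# Consumer sketch for the SSE frames yielded above. URL, port, and field
# values are assumptions; only the frame format comes from this diff.
import asyncio

import httpx


async def stream_financial_report_chat(question: str, space: str) -> None:
    payload = {
        "conv_uid": "demo-conv-1",
        "chat_mode": "chat_knowledge",
        "select_param": space,          # knowledge space with a domain_type set
        "model_name": "glm-4",          # example model
        "user_input": question,
        "incremental": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:5670/api/v1/chat/completions", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data:"):
                    continue
                chunk = line[len("data:"):]
                if chunk.startswith("[SERVER_ERROR]"):
                    raise RuntimeError(chunk)
                # The server escaped real newlines as "\n" literals; restore them.
                print(chunk.replace("\\n", "\n"), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(stream_financial_report_chat("Summarize 2023 revenue.", "fin_2023"))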