feat(ChatKnowledge): Support Financial Report Analysis (#1702)

Co-authored-by: hzh97 <2976151305@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: licunxing <864255598@qq.com>
This commit is contained in:
Aries-ckt
2024-07-26 13:40:54 +08:00
committed by GitHub
parent 22e0680a6a
commit 167d972093
160 changed files with 89339 additions and 795 deletions

View File

@@ -3,7 +3,7 @@ import logging
import os
import uuid
from concurrent.futures import Executor
from typing import List, Optional
from typing import List, Optional, cast
import aiofiles
from fastapi import APIRouter, Body, Depends, File, UploadFile
@@ -21,8 +21,11 @@ from dbgpt.app.openapi.api_view_model import (
)
from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene
from dbgpt.component import ComponentType
from dbgpt.configs import TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.core.awel import CommonLLMHttpRequestBody
from dbgpt.core.awel import BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.util.chat_util import safe_chat_stream_with_dag_task
from dbgpt.core.schema.api import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
@@ -127,6 +130,11 @@ def get_worker_manager() -> WorkerManager:
return worker_manager
def get_dag_manager() -> DAGManager:
    """Return the DAGManager instance registered for the configured system app."""
    system_app = CFG.SYSTEM_APP
    return DAGManager.get_instance(system_app)
def get_chat_flow() -> FlowService:
    """Return the FlowService instance registered for the configured system app."""
    system_app = CFG.SYSTEM_APP
    return FlowService.get_instance(system_app)
@@ -252,7 +260,7 @@ async def params_load(
sys_code: Optional[str] = None,
doc_file: UploadFile = File(...),
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
# Save the uploaded file
@@ -335,7 +343,7 @@ async def chat_completions(
dialogue: ConversationVo = Body(),
flow_service: FlowService = Depends(get_chat_flow),
):
print(
logger.info(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
headers = {
@@ -344,6 +352,7 @@ async def chat_completions(
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
domain_type = _parse_domain_type(dialogue)
if dialogue.chat_mode == ChatScene.ChatAgent.value():
return StreamingResponse(
multi_agents.app_agent_chat(
@@ -378,12 +387,20 @@ async def chat_completions(
headers=headers,
media_type="text/event-stream",
)
elif domain_type is not None:
return StreamingResponse(
chat_with_domain_flow(dialogue, domain_type),
headers=headers,
media_type="text/event-stream",
)
else:
with root_tracer.start_span(
"get_chat_instance",
span_type=SpanType.CHAT,
metadata=model_to_dict(dialogue),
):
chat: BaseChat = await get_chat_instance(dialogue)
if not chat.prompt_template.stream_out:
@@ -484,3 +501,61 @@ def message2Vo(message: dict, order, model_name) -> MessageVo:
order=order,
model_name=model_name,
)
def _parse_domain_type(dialogue: ConversationVo) -> Optional[str]:
    """Resolve the domain type of the knowledge space selected in a dialogue.

    Only the knowledge chat scene is inspected; for every other chat mode no
    lookup is performed.

    Args:
        dialogue: The conversation request. ``chat_mode`` selects the scene and
            ``select_param`` is used as the knowledge space name.

    Returns:
        The ``domain_type`` of the first matching knowledge space, or ``None``
        when the chat mode is not knowledge chat or the space has no (truthy)
        domain type.

    Raises:
        ValueError: If no knowledge space with the selected name exists.
    """
    if dialogue.chat_mode != ChatScene.ChatKnowledge.value():
        return None

    # Supported in the knowledge chat
    space_name = dialogue.select_param
    spaces = knowledge_service.get_knowledge_space(
        KnowledgeSpaceRequest(name=space_name)
    )
    if not spaces:
        # Raise instead of returning a Result object: the caller only checks
        # ``is not None`` and would otherwise treat the failure object as a
        # domain type and forward it to the domain flow.
        raise ValueError(f"Knowledge space {space_name} not found")
    return spaces[0].domain_type or None
async def chat_with_domain_flow(dialogue: ConversationVo, domain_type: str):
    """Stream a chat answer produced by the DAG tagged with a domain type.

    The first DAG tagged with ``domain_type`` is looked up, its first leaf
    node is used as the task to run, and every output chunk is yielded as a
    ``data:...`` event line followed by a blank line.

    Args:
        dialogue: The conversation request; ``select_param`` is used as the
            knowledge space name and the remaining fields are forwarded to
            the request body handed to the DAG.
        domain_type: Tag value used to find the DAG to run.

    Yields:
        ``"data:<text>\\n\\n"`` for each successful chunk, or a single
        ``"data:[SERVER_ERROR]<text>\\n\\n"`` after which the stream stops.

    Raises:
        ValueError: If no DAG (or no leaf node) exists for ``domain_type``,
            or if no local database name contains the space name.
    """
    dag_manager = get_dag_manager()
    dags = dag_manager.get_dags_by_tag(TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE, domain_type)
    if not dags or not dags[0].leaf_nodes:
        raise ValueError(f"Can't find the DAG for domain type {domain_type}")
    end_task = cast(BaseOperator, dags[0].leaf_nodes[0])

    space = dialogue.select_param
    connector_manager = CFG.local_db_manager
    # TODO: Some flows may not need a database connector.
    # NOTE(review): this is a substring match, so the first database whose
    # name merely contains the space name wins - confirm this is intended.
    db_name = next(
        (
            item["db_name"]
            for item in connector_manager.get_db_list()
            if space in item["db_name"]
        ),
        None,
    )
    if db_name is None:
        raise ValueError(f"fin report dbname {space}_fin_report not found.")

    flow_ctx = {"space": space, "db_name": db_name}
    request = CommonLLMHttpRequestBody(
        model=dialogue.model_name,
        messages=dialogue.user_input,
        stream=True,
        extra=flow_ctx,
        conv_uid=dialogue.conv_uid,
        span_id=root_tracer.get_current_span_id(),
        chat_mode=dialogue.chat_mode,
        chat_param=dialogue.select_param,
        user_name=dialogue.user_name,
        sys_code=dialogue.sys_code,
        incremental=dialogue.incremental,
    )
    # NOTE(review): the third positional argument presumably selects
    # non-incremental output - confirm against safe_chat_stream_with_dag_task.
    async for output in safe_chat_stream_with_dag_task(end_task, request, False):
        # Escape newlines so a chunk never spans several event lines; a
        # missing text is sent as an empty payload instead of "None".
        text = (output.text or "").replace("\n", "\\n")
        if output.error_code != 0:
            yield f"data:[SERVER_ERROR]{text}\n\n"
            break
        yield f"data:{text}\n\n"