mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-30 15:21:02 +00:00
feat(ChatKnowledge): return topk document source (KBQA) (#609)
1. Support score feedback 2. Add method to return topk document source (KBQA)
This commit is contained in:
commit
c830598c9e
18
assets/schema/history.sql
Normal file
18
assets/schema/history.sql
Normal file
@ -0,0 +1,18 @@
|
||||
CREATE DATABASE history;
|
||||
use history;
|
||||
CREATE TABLE `chat_feed_back` (
|
||||
`id` bigint(20) NOT NULL AUTO_INCREMENT,
|
||||
`conv_uid` varchar(128) DEFAULT NULL COMMENT '会话id',
|
||||
`conv_index` int(4) DEFAULT NULL COMMENT '第几轮会话',
|
||||
`score` int(1) DEFAULT NULL COMMENT '评分',
|
||||
`ques_type` varchar(32) DEFAULT NULL COMMENT '用户问题类别',
|
||||
`question` longtext DEFAULT NULL COMMENT '用户问题',
|
||||
`knowledge_space` varchar(128) DEFAULT NULL COMMENT '知识库',
|
||||
`messages` longtext DEFAULT NULL COMMENT '评价详情',
|
||||
`user_name` varchar(128) DEFAULT NULL COMMENT '评价人',
|
||||
`gmt_created` datetime DEFAULT NULL,
|
||||
`gmt_modified` datetime DEFAULT NULL,
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
|
||||
KEY `idx_conv` (`conv_uid`,`conv_index`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='用户评分反馈表';
|
@ -185,6 +185,9 @@ class Config(metaclass=Singleton):
|
||||
self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
|
||||
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
|
||||
)
|
||||
### Control whether to display the source document of knowledge on the front end.
|
||||
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
|
||||
|
||||
### SUMMARY_CONFIG Configuration
|
||||
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
|
||||
|
||||
|
@ -85,6 +85,26 @@ def plugins_select_info():
|
||||
return plugins_infos
|
||||
|
||||
|
||||
def get_db_list_info():
|
||||
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
|
||||
params: dict = {}
|
||||
for item in dbs:
|
||||
comment = item["comment"]
|
||||
if comment is not None and len(comment) > 0:
|
||||
params.update({item["db_name"]: comment})
|
||||
return params
|
||||
|
||||
|
||||
def knowledge_list_info():
|
||||
"""return knowledge space list"""
|
||||
params: dict = {}
|
||||
request = KnowledgeSpaceRequest()
|
||||
spaces = knowledge_service.get_knowledge_space(request)
|
||||
for space in spaces:
|
||||
params.update({space.name: space.desc})
|
||||
return params
|
||||
|
||||
|
||||
def knowledge_list():
|
||||
"""return knowledge space list"""
|
||||
params: dict = {}
|
||||
|
0
pilot/openapi/api_v1/feedback/__init__.py
Normal file
0
pilot/openapi/api_v1/feedback/__init__.py
Normal file
48
pilot/openapi/api_v1/feedback/api_fb_v1.py
Normal file
48
pilot/openapi/api_v1/feedback/api_fb_v1.py
Normal file
@ -0,0 +1,48 @@
|
||||
from fastapi import APIRouter, Body, Request
|
||||
|
||||
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||
from pilot.openapi.api_v1.feedback.feed_back_db import (
|
||||
ChatFeedBackDao,
|
||||
ChatFeedBackEntity,
|
||||
)
|
||||
from pilot.openapi.api_view_model import Result
|
||||
|
||||
router = APIRouter()
|
||||
chat_feed_back = ChatFeedBackDao()
|
||||
|
||||
|
||||
@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
|
||||
async def feed_back_find(conv_uid: str, conv_index: int):
|
||||
rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
|
||||
if rt is not None:
|
||||
return Result.succ(
|
||||
FeedBackBody(
|
||||
conv_uid=rt.conv_uid,
|
||||
conv_index=rt.conv_index,
|
||||
question=rt.question,
|
||||
knowledge_space=rt.knowledge_space,
|
||||
score=rt.score,
|
||||
ques_type=rt.ques_type,
|
||||
messages=rt.messages,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@router.post("/v1/feedback/commit", response_model=Result[bool])
|
||||
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
|
||||
chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
|
||||
return Result.succ(True)
|
||||
|
||||
|
||||
@router.get("/v1/feedback/select", response_model=Result[dict])
|
||||
async def feed_back_select():
|
||||
return Result.succ(
|
||||
{
|
||||
"information": "信息查询",
|
||||
"work_study": "工作学习",
|
||||
"just_fun": "互动闲聊",
|
||||
"others": "其他",
|
||||
}
|
||||
)
|
84
pilot/openapi/api_v1/feedback/feed_back_db.py
Normal file
84
pilot/openapi/api_v1/feedback/feed_back_db.py
Normal file
@ -0,0 +1,84 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class ChatFeedBackEntity(Base):
|
||||
__tablename__ = "chat_feed_back"
|
||||
id = Column(Integer, primary_key=True)
|
||||
conv_uid = Column(String(128))
|
||||
conv_index = Column(Integer)
|
||||
score = Column(Integer)
|
||||
ques_type = Column(String(32))
|
||||
question = Column(Text)
|
||||
knowledge_space = Column(String(128))
|
||||
messages = Column(Text)
|
||||
user_name = Column(String(128))
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', "
|
||||
f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
|
||||
f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
)
|
||||
|
||||
|
||||
class ChatFeedBackDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(database="history", orm_base=Base, create_not_exist_table=True)
|
||||
|
||||
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
|
||||
# Todo: We need to have user information first.
|
||||
def_user_name = ""
|
||||
|
||||
session = self.Session()
|
||||
chat_feed_back = ChatFeedBackEntity(
|
||||
conv_uid=feed_back.conv_uid,
|
||||
conv_index=feed_back.conv_index,
|
||||
score=feed_back.score,
|
||||
ques_type=feed_back.ques_type,
|
||||
question=feed_back.question,
|
||||
knowledge_space=feed_back.knowledge_space,
|
||||
messages=feed_back.messages,
|
||||
user_name=def_user_name,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
result = (
|
||||
session.query(ChatFeedBackEntity)
|
||||
.filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
|
||||
.filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
|
||||
.first()
|
||||
)
|
||||
if result is not None:
|
||||
result.score = feed_back.score
|
||||
result.ques_type = feed_back.ques_type
|
||||
result.question = feed_back.question
|
||||
result.knowledge_space = feed_back.knowledge_space
|
||||
result.messages = feed_back.messages
|
||||
result.user_name = def_user_name
|
||||
result.gmt_created = datetime.now()
|
||||
result.gmt_modified = datetime.now()
|
||||
else:
|
||||
session.merge(chat_feed_back)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
|
||||
session = self.Session()
|
||||
result = (
|
||||
session.query(ChatFeedBackEntity)
|
||||
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
|
||||
.filter(ChatFeedBackEntity.conv_index == conv_index)
|
||||
.first()
|
||||
)
|
||||
session.close()
|
||||
return result
|
25
pilot/openapi/api_v1/feedback/feed_back_model.py
Normal file
25
pilot/openapi/api_v1/feedback/feed_back_model.py
Normal file
@ -0,0 +1,25 @@
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
|
||||
class FeedBackBody(BaseModel):
|
||||
"""conv_uid: conversation id"""
|
||||
|
||||
conv_uid: str
|
||||
|
||||
"""conv_index: conversation index"""
|
||||
conv_index: int
|
||||
|
||||
"""question: human question"""
|
||||
question: str
|
||||
|
||||
"""knowledge_space: knowledge space"""
|
||||
knowledge_space: str
|
||||
|
||||
"""score: rating of the llm's answer"""
|
||||
score: int
|
||||
|
||||
"""ques_type: question type"""
|
||||
ques_type: str
|
||||
|
||||
"""messages: rating detail"""
|
||||
messages: str
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
@ -55,6 +56,21 @@ class ChatKnowledge(BaseChat):
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
self.prompt_template.template_is_strict = False
|
||||
|
||||
async def stream_call(self):
|
||||
input_values = self.generate_input_values()
|
||||
async for output in super().stream_call():
|
||||
# Source of knowledge file
|
||||
relations = input_values.get("relations")
|
||||
if (
|
||||
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
|
||||
and type(relations) == list
|
||||
and len(relations) > 0
|
||||
and hasattr(output, "text")
|
||||
):
|
||||
output.text = output.text + "\trelations:" + ",".join(relations)
|
||||
yield output
|
||||
|
||||
def generate_input_values(self):
|
||||
if self.space_context:
|
||||
@ -69,7 +85,14 @@ class ChatKnowledge(BaseChat):
|
||||
)
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[: self.max_token]
|
||||
input_values = {"context": context, "question": self.current_user_input}
|
||||
relations = list(
|
||||
set([os.path.basename(d.metadata.get("source")) for d in docs])
|
||||
)
|
||||
input_values = {
|
||||
"context": context,
|
||||
"question": self.current_user_input,
|
||||
"relations": relations,
|
||||
}
|
||||
return input_values
|
||||
|
||||
@property
|
||||
|
@ -29,6 +29,7 @@ from pilot.server.prompt.api import router as prompt_router
|
||||
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||
from pilot.openapi.base import validation_exception_handler
|
||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
|
||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||
from pilot.model.cluster import initialize_worker_manager_in_client
|
||||
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
|
||||
@ -72,6 +73,7 @@ app.add_middleware(
|
||||
app.include_router(api_v1, prefix="/api")
|
||||
app.include_router(knowledge_router, prefix="/api")
|
||||
app.include_router(api_editor_route_v1, prefix="/api")
|
||||
app.include_router(api_fb_v1, prefix="/api")
|
||||
|
||||
# app.include_router(api_v1)
|
||||
app.include_router(knowledge_router)
|
||||
|
Loading…
Reference in New Issue
Block a user