mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-15 06:53:12 +00:00
feat:llm manage
This commit is contained in:
commit
d512ddeae9
18
assets/schema/history.sql
Normal file
18
assets/schema/history.sql
Normal file
@ -0,0 +1,18 @@
|
||||
CREATE DATABASE history;
|
||||
use history;
|
||||
CREATE TABLE `chat_feed_back` (
|
||||
`id` bigint(20) NOT NULL AUTO_INCREMENT,
|
||||
`conv_uid` varchar(128) DEFAULT NULL COMMENT '会话id',
|
||||
`conv_index` int(4) DEFAULT NULL COMMENT '第几轮会话',
|
||||
`score` int(1) DEFAULT NULL COMMENT '评分',
|
||||
`ques_type` varchar(32) DEFAULT NULL COMMENT '用户问题类别',
|
||||
`question` longtext DEFAULT NULL COMMENT '用户问题',
|
||||
`knowledge_space` varchar(128) DEFAULT NULL COMMENT '知识库',
|
||||
`messages` longtext DEFAULT NULL COMMENT '评价详情',
|
||||
`user_name` varchar(128) DEFAULT NULL COMMENT '评价人',
|
||||
`gmt_created` datetime DEFAULT NULL,
|
||||
`gmt_modified` datetime DEFAULT NULL,
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
|
||||
KEY `idx_conv` (`conv_uid`,`conv_index`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='用户评分反馈表';
|
@ -185,6 +185,9 @@ class Config(metaclass=Singleton):
|
||||
self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
|
||||
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
|
||||
)
|
||||
### Control whether to display the source document of knowledge on the front end.
|
||||
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
|
||||
|
||||
### SUMMARY_CONFIG Configuration
|
||||
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
|
||||
|
||||
|
@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.model_utils import _clear_model_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
del self.tokenizer
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
_clear_model_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
torch_imported = False
|
||||
|
@ -11,7 +11,7 @@ from pilot.model.parameter import (
|
||||
)
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.model_utils import _clear_model_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
return
|
||||
del self._embeddings_impl
|
||||
self._embeddings_impl = None
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
_clear_model_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict):
|
||||
"""Generate stream result, chat scene"""
|
||||
|
@ -18,6 +18,7 @@ from pilot.logs import logger
|
||||
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
|
||||
# TODO: vicuna-v1.5 8-bit quantization info is slow
|
||||
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
|
||||
# TODO: support internlm quantization
|
||||
model_name = model_params.model_name.lower()
|
||||
supported_models = ["llama", "baichuan", "vicuna"]
|
||||
return any(m in model_name for m in supported_models)
|
||||
|
@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
|
||||
ConversationVo,
|
||||
MessageVo,
|
||||
ChatSceneVo,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
DeltaMessage,
|
||||
ChatCompletionStreamResponse,
|
||||
)
|
||||
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
|
||||
from pilot.configs.config import Config
|
||||
@ -85,6 +88,26 @@ def plugins_select_info():
|
||||
return plugins_infos
|
||||
|
||||
|
||||
def get_db_list_info():
|
||||
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
|
||||
params: dict = {}
|
||||
for item in dbs:
|
||||
comment = item["comment"]
|
||||
if comment is not None and len(comment) > 0:
|
||||
params.update({item["db_name"]: comment})
|
||||
return params
|
||||
|
||||
|
||||
def knowledge_list_info():
|
||||
"""return knowledge space list"""
|
||||
params: dict = {}
|
||||
request = KnowledgeSpaceRequest()
|
||||
spaces = knowledge_service.get_knowledge_space(request)
|
||||
for space in spaces:
|
||||
params.update({space.name: space.desc})
|
||||
return params
|
||||
|
||||
|
||||
def knowledge_list():
|
||||
"""return knowledge space list"""
|
||||
params: dict = {}
|
||||
@ -363,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
stream_generator(chat),
|
||||
stream_generator(chat, dialogue.incremental, dialogue.model_name),
|
||||
headers=headers,
|
||||
media_type="text/plain",
|
||||
)
|
||||
@ -401,19 +424,48 @@ async def no_stream_generator(chat):
|
||||
yield f"data: {msg}\n\n"
|
||||
|
||||
|
||||
async def stream_generator(chat):
|
||||
async def stream_generator(chat, incremental: bool, model_name: str):
|
||||
"""Generate streaming responses
|
||||
|
||||
Our goal is to generate an openai-compatible streaming responses.
|
||||
Currently, the incremental response is compatible, and the full response will be transformed in the future.
|
||||
|
||||
Args:
|
||||
chat (BaseChat): Chat instance.
|
||||
incremental (bool): Used to control whether the content is returned incrementally or in full each time.
|
||||
model_name (str): The model name
|
||||
|
||||
Yields:
|
||||
_type_: streaming responses
|
||||
"""
|
||||
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
|
||||
|
||||
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
|
||||
previous_response = ""
|
||||
async for chunk in chat.stream_call():
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
msg = msg.replace("\ufffd", "")
|
||||
if incremental:
|
||||
incremental_output = msg[len(previous_response) :]
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant", content=incremental_output),
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=stream_id, choices=[choice_data], model=model_name
|
||||
)
|
||||
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
# TODO generate an openai-compatible streaming responses
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
previous_response = msg
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
if incremental:
|
||||
yield "data: [DONE]\n\n"
|
||||
chat.current_message.add_ai_message(msg)
|
||||
chat.current_message.add_view_message(msg)
|
||||
chat.memory.append(chat.current_message)
|
||||
|
0
pilot/openapi/api_v1/feedback/__init__.py
Normal file
0
pilot/openapi/api_v1/feedback/__init__.py
Normal file
48
pilot/openapi/api_v1/feedback/api_fb_v1.py
Normal file
48
pilot/openapi/api_v1/feedback/api_fb_v1.py
Normal file
@ -0,0 +1,48 @@
|
||||
from fastapi import APIRouter, Body, Request
|
||||
|
||||
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||
from pilot.openapi.api_v1.feedback.feed_back_db import (
|
||||
ChatFeedBackDao,
|
||||
ChatFeedBackEntity,
|
||||
)
|
||||
from pilot.openapi.api_view_model import Result
|
||||
|
||||
router = APIRouter()
|
||||
chat_feed_back = ChatFeedBackDao()
|
||||
|
||||
|
||||
@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
|
||||
async def feed_back_find(conv_uid: str, conv_index: int):
|
||||
rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
|
||||
if rt is not None:
|
||||
return Result.succ(
|
||||
FeedBackBody(
|
||||
conv_uid=rt.conv_uid,
|
||||
conv_index=rt.conv_index,
|
||||
question=rt.question,
|
||||
knowledge_space=rt.knowledge_space,
|
||||
score=rt.score,
|
||||
ques_type=rt.ques_type,
|
||||
messages=rt.messages,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@router.post("/v1/feedback/commit", response_model=Result[bool])
|
||||
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
|
||||
chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
|
||||
return Result.succ(True)
|
||||
|
||||
|
||||
@router.get("/v1/feedback/select", response_model=Result[dict])
|
||||
async def feed_back_select():
|
||||
return Result.succ(
|
||||
{
|
||||
"information": "信息查询",
|
||||
"work_study": "工作学习",
|
||||
"just_fun": "互动闲聊",
|
||||
"others": "其他",
|
||||
}
|
||||
)
|
84
pilot/openapi/api_v1/feedback/feed_back_db.py
Normal file
84
pilot/openapi/api_v1/feedback/feed_back_db.py
Normal file
@ -0,0 +1,84 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class ChatFeedBackEntity(Base):
|
||||
__tablename__ = "chat_feed_back"
|
||||
id = Column(Integer, primary_key=True)
|
||||
conv_uid = Column(String(128))
|
||||
conv_index = Column(Integer)
|
||||
score = Column(Integer)
|
||||
ques_type = Column(String(32))
|
||||
question = Column(Text)
|
||||
knowledge_space = Column(String(128))
|
||||
messages = Column(Text)
|
||||
user_name = Column(String(128))
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', "
|
||||
f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
|
||||
f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
)
|
||||
|
||||
|
||||
class ChatFeedBackDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(database="history", orm_base=Base, create_not_exist_table=True)
|
||||
|
||||
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
|
||||
# Todo: We need to have user information first.
|
||||
def_user_name = ""
|
||||
|
||||
session = self.Session()
|
||||
chat_feed_back = ChatFeedBackEntity(
|
||||
conv_uid=feed_back.conv_uid,
|
||||
conv_index=feed_back.conv_index,
|
||||
score=feed_back.score,
|
||||
ques_type=feed_back.ques_type,
|
||||
question=feed_back.question,
|
||||
knowledge_space=feed_back.knowledge_space,
|
||||
messages=feed_back.messages,
|
||||
user_name=def_user_name,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
result = (
|
||||
session.query(ChatFeedBackEntity)
|
||||
.filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
|
||||
.filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
|
||||
.first()
|
||||
)
|
||||
if result is not None:
|
||||
result.score = feed_back.score
|
||||
result.ques_type = feed_back.ques_type
|
||||
result.question = feed_back.question
|
||||
result.knowledge_space = feed_back.knowledge_space
|
||||
result.messages = feed_back.messages
|
||||
result.user_name = def_user_name
|
||||
result.gmt_created = datetime.now()
|
||||
result.gmt_modified = datetime.now()
|
||||
else:
|
||||
session.merge(chat_feed_back)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
|
||||
session = self.Session()
|
||||
result = (
|
||||
session.query(ChatFeedBackEntity)
|
||||
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
|
||||
.filter(ChatFeedBackEntity.conv_index == conv_index)
|
||||
.first()
|
||||
)
|
||||
session.close()
|
||||
return result
|
25
pilot/openapi/api_v1/feedback/feed_back_model.py
Normal file
25
pilot/openapi/api_v1/feedback/feed_back_model.py
Normal file
@ -0,0 +1,25 @@
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
|
||||
class FeedBackBody(BaseModel):
|
||||
"""conv_uid: conversation id"""
|
||||
|
||||
conv_uid: str
|
||||
|
||||
"""conv_index: conversation index"""
|
||||
conv_index: int
|
||||
|
||||
"""question: human question"""
|
||||
question: str
|
||||
|
||||
"""knowledge_space: knowledge space"""
|
||||
knowledge_space: str
|
||||
|
||||
"""score: rating of the llm's answer"""
|
||||
score: int
|
||||
|
||||
"""ques_type: question type"""
|
||||
ques_type: str
|
||||
|
||||
"""messages: rating detail"""
|
||||
messages: str
|
@ -1,5 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Generic, Any
|
||||
from typing import TypeVar, Generic, Any, Optional, Literal, List
|
||||
import uuid
|
||||
import time
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
|
||||
"""
|
||||
model_name: str = None
|
||||
|
||||
"""Used to control whether the content is returned incrementally or in full each time.
|
||||
If this parameter is not provided, the default is full return.
|
||||
"""
|
||||
incremental: bool = False
|
||||
|
||||
|
||||
class MessageVo(BaseModel):
|
||||
"""
|
||||
@ -83,3 +90,21 @@ class MessageVo(BaseModel):
|
||||
model_name
|
||||
"""
|
||||
model_name: str
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
@ -55,6 +56,21 @@ class ChatKnowledge(BaseChat):
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
self.prompt_template.template_is_strict = False
|
||||
|
||||
async def stream_call(self):
|
||||
input_values = self.generate_input_values()
|
||||
async for output in super().stream_call():
|
||||
# Source of knowledge file
|
||||
relations = input_values.get("relations")
|
||||
if (
|
||||
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
|
||||
and type(relations) == list
|
||||
and len(relations) > 0
|
||||
and hasattr(output, "text")
|
||||
):
|
||||
output.text = output.text + "\trelations:" + ",".join(relations)
|
||||
yield output
|
||||
|
||||
def generate_input_values(self):
|
||||
if self.space_context:
|
||||
@ -69,7 +85,14 @@ class ChatKnowledge(BaseChat):
|
||||
)
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[: self.max_token]
|
||||
input_values = {"context": context, "question": self.current_user_input}
|
||||
relations = list(
|
||||
set([os.path.basename(d.metadata.get("source")) for d in docs])
|
||||
)
|
||||
input_values = {
|
||||
"context": context,
|
||||
"question": self.current_user_input,
|
||||
"relations": relations,
|
||||
}
|
||||
return input_values
|
||||
|
||||
@property
|
||||
|
@ -30,6 +30,7 @@ from pilot.server.llm_manage.api import router as llm_manage_api
|
||||
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||
from pilot.openapi.base import validation_exception_handler
|
||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
|
||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||
from pilot.model.cluster import initialize_worker_manager_in_client
|
||||
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
|
||||
@ -74,6 +75,7 @@ app.include_router(api_v1, prefix="/api")
|
||||
app.include_router(knowledge_router, prefix="/api")
|
||||
app.include_router(api_editor_route_v1, prefix="/api")
|
||||
app.include_router(llm_manage_api, prefix="/api")
|
||||
app.include_router(api_fb_v1, prefix="/api")
|
||||
|
||||
# app.include_router(api_v1)
|
||||
app.include_router(knowledge_router)
|
||||
|
@ -1,10 +1,22 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _clear_model_cache(device="cuda"):
|
||||
try:
|
||||
# clear torch cache
|
||||
import torch
|
||||
|
||||
_clear_torch_cache(device)
|
||||
except ImportError:
|
||||
logger.warn("Torch not installed, skip clear torch cache")
|
||||
# TODO clear other cache
|
||||
|
||||
|
||||
def _clear_torch_cache(device="cuda"):
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
if device != "cpu":
|
||||
@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
|
||||
|
||||
empty_cache()
|
||||
except Exception as e:
|
||||
logging.warn(f"Clear mps torch cache error, {str(e)}")
|
||||
logger.warn(f"Clear mps torch cache error, {str(e)}")
|
||||
elif torch.has_cuda:
|
||||
device_count = torch.cuda.device_count()
|
||||
for device_id in range(device_count):
|
||||
cuda_device = f"cuda:{device_id}"
|
||||
logging.info(f"Clear torch cache of device: {cuda_device}")
|
||||
logger.info(f"Clear torch cache of device: {cuda_device}")
|
||||
with torch.cuda.device(cuda_device):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
else:
|
||||
logging.info("No cuda or mps, not support clear torch cache yet")
|
||||
logger.info("No cuda or mps, not support clear torch cache yet")
|
||||
|
Loading…
Reference in New Issue
Block a user