feat:llm manage

This commit is contained in:
aries_ckt 2023-09-22 00:30:34 +08:00
commit d512ddeae9
14 changed files with 310 additions and 17 deletions

18
assets/schema/history.sql Normal file
View File

@ -0,0 +1,18 @@
CREATE DATABASE history;
use history;
CREATE TABLE `chat_feed_back` (
`id` bigint(20) NOT NULL AUTO_INCREMENT,
`conv_uid` varchar(128) DEFAULT NULL COMMENT '会话id',
`conv_index` int(4) DEFAULT NULL COMMENT '第几轮会话',
`score` int(1) DEFAULT NULL COMMENT '评分',
`ques_type` varchar(32) DEFAULT NULL COMMENT '用户问题类别',
`question` longtext DEFAULT NULL COMMENT '用户问题',
`knowledge_space` varchar(128) DEFAULT NULL COMMENT '知识库',
`messages` longtext DEFAULT NULL COMMENT '评价详情',
`user_name` varchar(128) DEFAULT NULL COMMENT '评价人',
`gmt_created` datetime DEFAULT NULL,
`gmt_modified` datetime DEFAULT NULL,
PRIMARY KEY (`id`),
UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
KEY `idx_conv` (`conv_uid`,`conv_index`)
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='用户评分反馈表';

View File

@ -185,6 +185,9 @@ class Config(metaclass=Singleton):
self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
)
### Control whether to display the source document of knowledge on the front end.
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

View File

@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger(__name__)
@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
del self.tokenizer
self.model = None
self.tokenizer = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
torch_imported = False

View File

@ -11,7 +11,7 @@ from pilot.model.parameter import (
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger(__name__)
@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
return
del self._embeddings_impl
self._embeddings_impl = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict):
"""Generate stream result, chat scene"""

View File

@ -18,6 +18,7 @@ from pilot.logs import logger
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
# TODO: vicuna-v1.5 8-bit quantization info is slow
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
# TODO: support internlm quantization
model_name = model_params.model_name.lower()
supported_models = ["llama", "baichuan", "vicuna"]
return any(m in model_name for m in supported_models)

View File

@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
ConversationVo,
MessageVo,
ChatSceneVo,
ChatCompletionResponseStreamChoice,
DeltaMessage,
ChatCompletionStreamResponse,
)
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config
@ -85,6 +88,26 @@ def plugins_select_info():
return plugins_infos
def get_db_list_info():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
params: dict = {}
for item in dbs:
comment = item["comment"]
if comment is not None and len(comment) > 0:
params.update({item["db_name"]: comment})
return params
def knowledge_list_info():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest()
spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.desc})
return params
def knowledge_list():
"""return knowledge space list"""
params: dict = {}
@ -363,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
)
else:
return StreamingResponse(
stream_generator(chat),
stream_generator(chat, dialogue.incremental, dialogue.model_name),
headers=headers,
media_type="text/plain",
)
@ -401,19 +424,48 @@ async def no_stream_generator(chat):
yield f"data: {msg}\n\n"
async def stream_generator(chat):
async def stream_generator(chat, incremental: bool, model_name: str):
"""Generate streaming responses
Our goal is to generate an openai-compatible streaming responses.
Currently, the incremental response is compatible, and the full response will be transformed in the future.
Args:
chat (BaseChat): Chat instance.
incremental (bool): Used to control whether the content is returned incrementally or in full each time.
model_name (str): The model name
Yields:
_type_: streaming responses
"""
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
previous_response = ""
async for chunk in chat.stream_call():
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
msg = msg.replace("\ufffd", "")
if incremental:
incremental_output = msg[len(previous_response) :]
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=incremental_output),
)
chunk = ChatCompletionStreamResponse(
id=stream_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
else:
# TODO generate an openai-compatible streaming responses
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)
if incremental:
yield "data: [DONE]\n\n"
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message)

View File

@ -0,0 +1,48 @@
from fastapi import APIRouter, Body, Request
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
from pilot.openapi.api_v1.feedback.feed_back_db import (
ChatFeedBackDao,
ChatFeedBackEntity,
)
from pilot.openapi.api_view_model import Result
router = APIRouter()
chat_feed_back = ChatFeedBackDao()
@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
async def feed_back_find(conv_uid: str, conv_index: int):
rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
if rt is not None:
return Result.succ(
FeedBackBody(
conv_uid=rt.conv_uid,
conv_index=rt.conv_index,
question=rt.question,
knowledge_space=rt.knowledge_space,
score=rt.score,
ques_type=rt.ques_type,
messages=rt.messages,
)
)
else:
return Result.succ(None)
@router.post("/v1/feedback/commit", response_model=Result[bool])
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
return Result.succ(True)
@router.get("/v1/feedback/select", response_model=Result[dict])
async def feed_back_select():
return Result.succ(
{
"information": "信息查询",
"work_study": "工作学习",
"just_fun": "互动闲聊",
"others": "其他",
}
)

View File

@ -0,0 +1,84 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
Base = declarative_base()
class ChatFeedBackEntity(Base):
__tablename__ = "chat_feed_back"
id = Column(Integer, primary_key=True)
conv_uid = Column(String(128))
conv_index = Column(Integer)
score = Column(Integer)
ques_type = Column(String(32))
question = Column(Text)
knowledge_space = Column(String(128))
messages = Column(Text)
user_name = Column(String(128))
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return (
f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', "
f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
)
class ChatFeedBackDao(BaseDao):
def __init__(self):
super().__init__(database="history", orm_base=Base, create_not_exist_table=True)
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
# Todo: We need to have user information first.
def_user_name = ""
session = self.Session()
chat_feed_back = ChatFeedBackEntity(
conv_uid=feed_back.conv_uid,
conv_index=feed_back.conv_index,
score=feed_back.score,
ques_type=feed_back.ques_type,
question=feed_back.question,
knowledge_space=feed_back.knowledge_space,
messages=feed_back.messages,
user_name=def_user_name,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
result = (
session.query(ChatFeedBackEntity)
.filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
.filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
.first()
)
if result is not None:
result.score = feed_back.score
result.ques_type = feed_back.ques_type
result.question = feed_back.question
result.knowledge_space = feed_back.knowledge_space
result.messages = feed_back.messages
result.user_name = def_user_name
result.gmt_created = datetime.now()
result.gmt_modified = datetime.now()
else:
session.merge(chat_feed_back)
session.commit()
session.close()
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
session = self.Session()
result = (
session.query(ChatFeedBackEntity)
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
.filter(ChatFeedBackEntity.conv_index == conv_index)
.first()
)
session.close()
return result

View File

@ -0,0 +1,25 @@
from pydantic.main import BaseModel
class FeedBackBody(BaseModel):
"""conv_uid: conversation id"""
conv_uid: str
"""conv_index: conversation index"""
conv_index: int
"""question: human question"""
question: str
"""knowledge_space: knowledge space"""
knowledge_space: str
"""score: rating of the llm's answer"""
score: int
"""ques_type: question type"""
ques_type: str
"""messages: rating detail"""
messages: str

View File

@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
from typing import TypeVar, Generic, Any, Optional, Literal, List
import uuid
import time
T = TypeVar("T")
@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
"""
model_name: str = None
"""Used to control whether the content is returned incrementally or in full each time.
If this parameter is not provided, the default is full return.
"""
incremental: bool = False
class MessageVo(BaseModel):
"""
@ -83,3 +90,21 @@ class MessageVo(BaseModel):
model_name
"""
model_name: str
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]

View File

@ -1,3 +1,4 @@
import os
from typing import Dict
from pilot.scene.base_chat import BaseChat
@ -55,6 +56,21 @@ class ChatKnowledge(BaseChat):
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
self.prompt_template.template_is_strict = False
async def stream_call(self):
input_values = self.generate_input_values()
async for output in super().stream_call():
# Source of knowledge file
relations = input_values.get("relations")
if (
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
and type(relations) == list
and len(relations) > 0
and hasattr(output, "text")
):
output.text = output.text + "\trelations:" + ",".join(relations)
yield output
def generate_input_values(self):
if self.space_context:
@ -69,7 +85,14 @@ class ChatKnowledge(BaseChat):
)
context = [d.page_content for d in docs]
context = context[: self.max_token]
input_values = {"context": context, "question": self.current_user_input}
relations = list(
set([os.path.basename(d.metadata.get("source")) for d in docs])
)
input_values = {
"context": context,
"question": self.current_user_input,
"relations": relations,
}
return input_values
@property

View File

@ -30,6 +30,7 @@ from pilot.server.llm_manage.api import router as llm_manage_api
from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
@ -74,6 +75,7 @@ app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(llm_manage_api, prefix="/api")
app.include_router(api_fb_v1, prefix="/api")
# app.include_router(api_v1)
app.include_router(knowledge_router)

View File

@ -1,10 +1,22 @@
import logging
logger = logging.getLogger(__name__)
def _clear_model_cache(device="cuda"):
try:
# clear torch cache
import torch
_clear_torch_cache(device)
except ImportError:
logger.warn("Torch not installed, skip clear torch cache")
# TODO clear other cache
def _clear_torch_cache(device="cuda"):
import gc
import torch
import gc
gc.collect()
if device != "cpu":
@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
empty_cache()
except Exception as e:
logging.warn(f"Clear mps torch cache error, {str(e)}")
logger.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logging.info(f"Clear torch cache of device: {cuda_device}")
logger.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logging.info("No cuda or mps, not support clear torch cache yet")
logger.info("No cuda or mps, not support clear torch cache yet")