mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-17 15:58:25 +00:00)

commit d512ddeae9 ("feat:llm manage")

assets/schema/history.sql (new file, 18 lines)
@@ -0,0 +1,18 @@
CREATE DATABASE history;
use history;
CREATE TABLE `chat_feed_back` (
  `id` bigint(20) NOT NULL AUTO_INCREMENT,
  `conv_uid` varchar(128) DEFAULT NULL COMMENT 'conversation id',
  `conv_index` int(4) DEFAULT NULL COMMENT 'conversation round',
  `score` int(1) DEFAULT NULL COMMENT 'rating score',
  `ques_type` varchar(32) DEFAULT NULL COMMENT 'user question category',
  `question` longtext DEFAULT NULL COMMENT 'user question',
  `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'knowledge space',
  `messages` longtext DEFAULT NULL COMMENT 'feedback details',
  `user_name` varchar(128) DEFAULT NULL COMMENT 'reviewer',
  `gmt_created` datetime DEFAULT NULL,
  `gmt_modified` datetime DEFAULT NULL,
  PRIMARY KEY (`id`),
  UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
  KEY `idx_conv` (`conv_uid`,`conv_index`)
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='user rating feedback table';

@@ -185,6 +185,9 @@ class Config(metaclass=Singleton):
        self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
            os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
        )
+        ### Control whether to display the source document of knowledge on the front end.
+        self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
+
        ### SUMMARY_CONFIG Configuration
        self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)

@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
        del self.tokenizer
        self.model = None
        self.tokenizer = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        torch_imported = False

@@ -11,7 +11,7 @@ from pilot.model.parameter import (
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)

@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
            return
        del self._embeddings_impl
        self._embeddings_impl = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict):
        """Generate stream result, chat scene"""

@@ -18,6 +18,7 @@ from pilot.logs import logger
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
    # TODO: vicuna-v1.5 8-bit quantization info is slow
    # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
+    # TODO: support internlm quantization
    model_name = model_params.model_name.lower()
    supported_models = ["llama", "baichuan", "vicuna"]
    return any(m in model_name for m in supported_models)

@@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
    ConversationVo,
    MessageVo,
    ChatSceneVo,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
+    ChatCompletionStreamResponse,
)
from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config

@@ -85,6 +88,26 @@ def plugins_select_info():
    return plugins_infos


+def get_db_list_info():
+    dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
+    params: dict = {}
+    for item in dbs:
+        comment = item["comment"]
+        if comment is not None and len(comment) > 0:
+            params.update({item["db_name"]: comment})
+    return params
+
+
+def knowledge_list_info():
+    """return knowledge space list"""
+    params: dict = {}
+    request = KnowledgeSpaceRequest()
+    spaces = knowledge_service.get_knowledge_space(request)
+    for space in spaces:
+        params.update({space.name: space.desc})
+    return params
+
+
def knowledge_list():
    """return knowledge space list"""
    params: dict = {}

@@ -363,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
        )
    else:
        return StreamingResponse(
-            stream_generator(chat),
+            stream_generator(chat, dialogue.incremental, dialogue.model_name),
            headers=headers,
            media_type="text/plain",
        )

@@ -401,19 +424,48 @@ async def no_stream_generator(chat):
        yield f"data: {msg}\n\n"


-async def stream_generator(chat):
+async def stream_generator(chat, incremental: bool, model_name: str):
+    """Generate streaming responses
+
+    Our goal is to generate an openai-compatible streaming responses.
+    Currently, the incremental response is compatible, and the full response will be transformed in the future.
+
+    Args:
+        chat (BaseChat): Chat instance.
+        incremental (bool): Used to control whether the content is returned incrementally or in full each time.
+        model_name (str): The model name
+
+    Yields:
+        _type_: streaming responses
+    """
    msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
+
+    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
+    previous_response = ""
    async for chunk in chat.stream_call():
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                chunk, chat.skip_echo_len
            )
-            msg = msg.replace("\n", "\\n")
-            yield f"data:{msg}\n\n"
+            msg = msg.replace("\ufffd", "")
+            if incremental:
+                incremental_output = msg[len(previous_response) :]
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role="assistant", content=incremental_output),
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=stream_id, choices=[choice_data], model=model_name
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            else:
+                # TODO generate an openai-compatible streaming responses
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+            previous_response = msg
        await asyncio.sleep(0.02)
+    if incremental:
+        yield "data: [DONE]\n\n"
    chat.current_message.add_ai_message(msg)
    chat.current_message.add_view_message(msg)
    chat.memory.append(chat.current_message)

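Note: to consume the new incremental stream from a client, something like the sketch below should work. It is illustrative only: the route, host, and port are assumptions (chat_completions is mounted under the "/api" prefix, but the full path is not shown in this diff), and every payload field except model_name and incremental is a placeholder.

# Hypothetical client for the OpenAI-style incremental stream (assumed URL and payload values).
import json
import requests

payload = {
    "conv_uid": "3a6f0e2a-0000-0000-0000-000000000000",  # placeholder conversation id
    "chat_mode": "chat_normal",                          # assumed scene name
    "user_input": "Hello",                               # placeholder question
    "model_name": "vicuna-13b-v1.5",                     # any model served by the worker manager
    "incremental": True,                                 # new ConversationVo field from this commit
}

with requests.post(
    "http://localhost:5000/api/v1/chat/completions",     # assumed route and port
    json=payload,
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":                              # sent only when incremental=True
            break
        chunk = json.loads(data)
        # Each chunk mirrors ChatCompletionStreamResponse: choices[0].delta.content holds the new text.
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
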
pilot/openapi/api_v1/feedback/__init__.py (new empty file)

pilot/openapi/api_v1/feedback/api_fb_v1.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from fastapi import APIRouter, Body, Request

from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
from pilot.openapi.api_v1.feedback.feed_back_db import (
    ChatFeedBackDao,
    ChatFeedBackEntity,
)
from pilot.openapi.api_view_model import Result

router = APIRouter()
chat_feed_back = ChatFeedBackDao()


@router.get("/v1/feedback/find", response_model=Result[FeedBackBody])
async def feed_back_find(conv_uid: str, conv_index: int):
    rt = chat_feed_back.get_chat_feed_back(conv_uid, conv_index)
    if rt is not None:
        return Result.succ(
            FeedBackBody(
                conv_uid=rt.conv_uid,
                conv_index=rt.conv_index,
                question=rt.question,
                knowledge_space=rt.knowledge_space,
                score=rt.score,
                ques_type=rt.ques_type,
                messages=rt.messages,
            )
        )
    else:
        return Result.succ(None)


@router.post("/v1/feedback/commit", response_model=Result[bool])
async def feed_back_commit(request: Request, feed_back_body: FeedBackBody = Body()):
    chat_feed_back.create_or_update_chat_feed_back(feed_back_body)
    return Result.succ(True)


@router.get("/v1/feedback/select", response_model=Result[dict])
async def feed_back_select():
    return Result.succ(
        {
            "information": "信息查询",
            "work_study": "工作学习",
            "just_fun": "互动闲聊",
            "others": "其他",
        }
    )

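Note: once the router is registered (this commit adds it to dbgpt_server.py with prefix "/api"), the endpoints can be exercised as sketched below. The host, port, and all field values are assumptions for illustration; only the paths come from the router above.

# Hypothetical calls against the new feedback endpoints.
import requests

BASE = "http://localhost:5000"  # assumed DB-GPT webserver address

# List the selectable question categories.
print(requests.get(f"{BASE}/api/v1/feedback/select").json())

# Create or update feedback for round 1 of a conversation.
feedback = {
    "conv_uid": "3a6f0e2a-0000-0000-0000-000000000000",  # placeholder conversation id
    "conv_index": 1,
    "question": "What is DB-GPT?",
    "knowledge_space": "default",
    "score": 5,
    "ques_type": "information",
    "messages": "helpful and accurate",
}
print(requests.post(f"{BASE}/api/v1/feedback/commit", json=feedback).json())

# Read the stored feedback back.
print(
    requests.get(
        f"{BASE}/api/v1/feedback/find",
        params={"conv_uid": feedback["conv_uid"], "conv_index": 1},
    ).json()
)
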
pilot/openapi/api_v1/feedback/feed_back_db.py (new file, 84 lines)
@@ -0,0 +1,84 @@
from datetime import datetime

from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base

from pilot.connections.rdbms.base_dao import BaseDao
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody

Base = declarative_base()


class ChatFeedBackEntity(Base):
    __tablename__ = "chat_feed_back"
    id = Column(Integer, primary_key=True)
    conv_uid = Column(String(128))
    conv_index = Column(Integer)
    score = Column(Integer)
    ques_type = Column(String(32))
    question = Column(Text)
    knowledge_space = Column(String(128))
    messages = Column(Text)
    user_name = Column(String(128))
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return (
            f"ChatFeedBackEntity(id={self.id}, conv_uid='{self.conv_uid}', conv_index='{self.conv_index}', "
            f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
            f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
        )


class ChatFeedBackDao(BaseDao):
    def __init__(self):
        super().__init__(database="history", orm_base=Base, create_not_exist_table=True)

    def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
        # Todo: We need to have user information first.
        def_user_name = ""

        session = self.Session()
        chat_feed_back = ChatFeedBackEntity(
            conv_uid=feed_back.conv_uid,
            conv_index=feed_back.conv_index,
            score=feed_back.score,
            ques_type=feed_back.ques_type,
            question=feed_back.question,
            knowledge_space=feed_back.knowledge_space,
            messages=feed_back.messages,
            user_name=def_user_name,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now(),
        )
        result = (
            session.query(ChatFeedBackEntity)
            .filter(ChatFeedBackEntity.conv_uid == feed_back.conv_uid)
            .filter(ChatFeedBackEntity.conv_index == feed_back.conv_index)
            .first()
        )
        if result is not None:
            result.score = feed_back.score
            result.ques_type = feed_back.ques_type
            result.question = feed_back.question
            result.knowledge_space = feed_back.knowledge_space
            result.messages = feed_back.messages
            result.user_name = def_user_name
            result.gmt_created = datetime.now()
            result.gmt_modified = datetime.now()
        else:
            session.merge(chat_feed_back)
        session.commit()
        session.close()

    def get_chat_feed_back(self, conv_uid: str, conv_index: int):
        session = self.Session()
        result = (
            session.query(ChatFeedBackEntity)
            .filter(ChatFeedBackEntity.conv_uid == conv_uid)
            .filter(ChatFeedBackEntity.conv_index == conv_index)
            .first()
        )
        session.close()
        return result

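Note: the DAO can also be driven directly, e.g. from a maintenance script. A minimal sketch, assuming the history database created by assets/schema/history.sql is reachable through BaseDao's default connection settings:

# Hypothetical direct use of ChatFeedBackDao outside FastAPI.
from pilot.openapi.api_v1.feedback.feed_back_db import ChatFeedBackDao
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody

dao = ChatFeedBackDao()
dao.create_or_update_chat_feed_back(
    FeedBackBody(
        conv_uid="3a6f0e2a-0000-0000-0000-000000000000",  # placeholder conversation id
        conv_index=1,
        question="What is DB-GPT?",
        knowledge_space="default",
        score=4,
        ques_type="information",
        messages="good answer, one detail missing",
    )
)
print(dao.get_chat_feed_back("3a6f0e2a-0000-0000-0000-000000000000", 1))  # printed via __repr__
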
pilot/openapi/api_v1/feedback/feed_back_model.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from pydantic.main import BaseModel


class FeedBackBody(BaseModel):
    """conv_uid: conversation id"""

    conv_uid: str

    """conv_index: conversation index"""
    conv_index: int

    """question: human question"""
    question: str

    """knowledge_space: knowledge space"""
    knowledge_space: str

    """score: rating of the llm's answer"""
    score: int

    """ques_type: question type"""
    ques_type: str

    """messages: rating detail"""
    messages: str

@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any
+from typing import TypeVar, Generic, Any, Optional, Literal, List
+import uuid
+import time

T = TypeVar("T")

@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
    """
    model_name: str = None
+
+    """Used to control whether the content is returned incrementally or in full each time.
+    If this parameter is not provided, the default is full return.
+    """
+    incremental: bool = False


class MessageVo(BaseModel):
    """

@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
        model_name
    """
    model_name: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]

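Note: a quick way to see the wire format these models produce (values invented; .json(...) is the pydantic v1 serializer that stream_generator uses above):

# Hypothetical serialization check for the new stream-response models.
from pilot.openapi.api_view_model import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    DeltaMessage,
)

choice = ChatCompletionResponseStreamChoice(
    index=0, delta=DeltaMessage(role="assistant", content="Hel")
)
chunk = ChatCompletionStreamResponse(model="vicuna-13b-v1.5", choices=[choice])
# exclude_unset=True omits fields that were never explicitly assigned (e.g. finish_reason here).
print(chunk.json(exclude_unset=True, ensure_ascii=False))
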
@@ -1,3 +1,4 @@
+import os
from typing import Dict

from pilot.scene.base_chat import BaseChat

@@ -55,6 +56,21 @@ class ChatKnowledge(BaseChat):
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )
+        self.prompt_template.template_is_strict = False
+
+    async def stream_call(self):
+        input_values = self.generate_input_values()
+        async for output in super().stream_call():
+            # Source of knowledge file
+            relations = input_values.get("relations")
+            if (
+                CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
+                and type(relations) == list
+                and len(relations) > 0
+                and hasattr(output, "text")
+            ):
+                output.text = output.text + "\trelations:" + ",".join(relations)
+            yield output

    def generate_input_values(self):
        if self.space_context:

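Note: when CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS is turned on (it is hard-coded to False in the config change above), each streamed chunk gains a tab-separated suffix listing the source files. A rough illustration of how that suffix is built (paths invented):

# Illustration only: building the relations suffix from document metadata.
import os

sources = [
    "/data/knowledge/space1/intro.md",
    "/data/knowledge/space1/faq.md",
    "/data/knowledge/space1/intro.md",  # duplicates collapse via set()
]
relations = list(set(os.path.basename(s) for s in sources))
text = "DB-GPT is a database-centric LLM framework."
print(text + "\trelations:" + ",".join(relations))  # ordering depends on the set
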
@@ -69,7 +85,14 @@ class ChatKnowledge(BaseChat):
        )
        context = [d.page_content for d in docs]
        context = context[: self.max_token]
-        input_values = {"context": context, "question": self.current_user_input}
+        relations = list(
+            set([os.path.basename(d.metadata.get("source")) for d in docs])
+        )
+        input_values = {
+            "context": context,
+            "question": self.current_user_input,
+            "relations": relations,
+        }
        return input_values

    @property

|
@ -30,6 +30,7 @@ from pilot.server.llm_manage.api import router as llm_manage_api
|
|||||||
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||||
from pilot.openapi.base import validation_exception_handler
|
from pilot.openapi.base import validation_exception_handler
|
||||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||||
|
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
|
||||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||||
from pilot.model.cluster import initialize_worker_manager_in_client
|
from pilot.model.cluster import initialize_worker_manager_in_client
|
||||||
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
|
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
|
||||||
@@ -74,6 +75,7 @@ app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(llm_manage_api, prefix="/api")
+app.include_router(api_fb_v1, prefix="/api")

# app.include_router(api_v1)
app.include_router(knowledge_router)

@@ -1,10 +1,22 @@
import logging
+
+logger = logging.getLogger(__name__)
+
+
+def _clear_model_cache(device="cuda"):
+    try:
+        # clear torch cache
+        import torch
+
+        _clear_torch_cache(device)
+    except ImportError:
+        logger.warn("Torch not installed, skip clear torch cache")
+    # TODO clear other cache
+
+
def _clear_torch_cache(device="cuda"):
-    import gc
-
    import torch
+    import gc

    gc.collect()
    if device != "cpu":
@@ -14,14 +26,14 @@ def _clear_torch_cache(device="cuda"):

                empty_cache()
            except Exception as e:
-                logging.warn(f"Clear mps torch cache error, {str(e)}")
+                logger.warn(f"Clear mps torch cache error, {str(e)}")
        elif torch.has_cuda:
            device_count = torch.cuda.device_count()
            for device_id in range(device_count):
                cuda_device = f"cuda:{device_id}"
-                logging.info(f"Clear torch cache of device: {cuda_device}")
+                logger.info(f"Clear torch cache of device: {cuda_device}")
                with torch.cuda.device(cuda_device):
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
        else:
-            logging.info("No cuda or mps, not support clear torch cache yet")
+            logger.info("No cuda or mps, not support clear torch cache yet")

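Note: the new wrapper is what the workers above now call when a model is stopped; it degrades gracefully when torch is not installed. Direct use is trivial (the device strings are just examples):

# Hypothetical direct call; safe in a torch-less environment thanks to the ImportError guard.
from pilot.utils.model_utils import _clear_model_cache

_clear_model_cache("cuda")  # on a CUDA machine, empties the cache of every visible device after gc.collect()
_clear_model_cache("cpu")   # only runs gc.collect(); the GPU branches are skipped
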