mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-27 12:29:29 +00:00
lint: fix
This commit is contained in:
parent 4040e1592a
commit e95696c23d
@@ -1,21 +1,21 @@
-import markdown2
+import markdown2
 import pandas as pd
 
 
 def datas_to_table_html(data):
     df = pd.DataFrame(data[1:], columns=data[0])
     table_style = """<style>
     table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
     </style>"""
-    html_table = df.to_html(index=False, escape=False)
+    html_table = df.to_html(index=False, escape=False)
 
     html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
 
     return html.replace("\n", " ")
 
 
 def generate_markdown_table(data):
-    """\n    Generate a Markdown table\n    data: a 2D list holding the header row and the body rows\n    """
+    """\n    Generate a Markdown table\n    data: a 2D list holding the header row and the body rows\n"""
     # Get the number of columns
     num_cols = len(data[0])
     # Generate the header row
@@ -41,6 +41,7 @@ def generate_markdown_table(data):
 
     return table
 
+
 def generate_htm_table(data):
     markdown_text = generate_markdown_table(data)
     html_table = markdown2.markdown(markdown_text, extras=["tables"])
@@ -53,4 +54,4 @@ if __name__ == "__main__":
 
     table_style = """<style>\n    table {\n    border-collapse: collapse;\n    width: 100%;\n    }\n    th, td {\n    border: 1px solid #ddd;\n    padding: 8px;\n    text-align: center;\n    line-height: 150px; \n    }\n    th {\n    background-color: #f2f2f2;\n    color: #333;\n    font-weight: bold;\n    }\n    tr:nth-child(even) {\n    background-color: #f9f9f9;\n    }\n    tr:hover {\n    background-color: #f2f2f2;\n    }\n    </style>"""
 
-    print(table_style.replace("\n", " "))
+    print(table_style.replace("\n", " "))

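Note: a minimal usage sketch of the helpers above, not part of the commit. The import path pilot.common.markdown_text comes from an import later in this diff; the sample data is hypothetical.

```python
# Hypothetical data: a header row followed by body rows.
from pilot.common.markdown_text import datas_to_table_html, generate_htm_table

data = [
    ["name", "age"],
    ["alice", 30],
    ["bob", 25],
]

# Markdown -> HTML via markdown2 with the "tables" extra.
print(generate_htm_table(data))

# pandas DataFrame -> one-line styled HTML document.
print(datas_to_table_html(data))
```
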
@@ -1,8 +1,9 @@
 from enum import auto, Enum
 from typing import List, Any
 
+
 class SeparatorStyle(Enum):
-    SINGLE ="###"
+    SINGLE = "###"
     TWO = "</s>"
     THREE = auto()
     FOUR = auto()

@@ -30,16 +30,16 @@ class Database:
     """SQLAlchemy wrapper around a database."""
 
     def __init__(
-        self,
-        engine,
-        schema: Optional[str] = None,
-        metadata: Optional[MetaData] = None,
-        ignore_tables: Optional[List[str]] = None,
-        include_tables: Optional[List[str]] = None,
-        sample_rows_in_table_info: int = 3,
-        indexes_in_table_info: bool = False,
-        custom_table_info: Optional[dict] = None,
-        view_support: bool = False,
+        self,
+        engine,
+        schema: Optional[str] = None,
+        metadata: Optional[MetaData] = None,
+        ignore_tables: Optional[List[str]] = None,
+        include_tables: Optional[List[str]] = None,
+        sample_rows_in_table_info: int = 3,
+        indexes_in_table_info: bool = False,
+        custom_table_info: Optional[dict] = None,
+        view_support: bool = False,
     ):
         """Create engine from database URI."""
         self._engine = engine
@@ -119,7 +119,7 @@ class Database:
 
     @classmethod
     def from_uri(
-        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
+        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
     ) -> Database:
         """Construct a SQLAlchemy engine from URI."""
         _engine_args = engine_args or {}
@@ -148,7 +148,7 @@ class Database:
 
         self._metadata = MetaData()
         # sql = f"use {db_name}"
-        sql = text(f'use `{db_name}`')
+        sql = text(f"use `{db_name}`")
         session.execute(sql)
 
         # Process the table metadata
@@ -159,13 +159,17 @@ class Database:
         # tables list if view_support is True
         self._all_tables = set(
             self._inspector.get_table_names(schema=db_name)
-            + (self._inspector.get_view_names(schema=db_name) if self.view_support else [])
+            + (
+                self._inspector.get_view_names(schema=db_name)
+                if self.view_support
+                else []
+            )
         )
 
         return session
 
     def get_current_db_name(self, session) -> str:
-        return session.execute(text('SELECT DATABASE()')).scalar()
+        return session.execute(text("SELECT DATABASE()")).scalar()
 
     def table_simple_info(self, session):
         _sql = f"""
@@ -201,7 +205,7 @@ class Database:
             tbl
             for tbl in self._metadata.sorted_tables
             if tbl.name in set(all_table_names)
-            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
+            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
         ]
 
         tables = []
@@ -214,7 +218,7 @@ class Database:
             create_table = str(CreateTable(table).compile(self._engine))
             table_info = f"{create_table.rstrip()}"
             has_extra_info = (
-                self._indexes_in_table_info or self._sample_rows_in_table_info
+                self._indexes_in_table_info or self._sample_rows_in_table_info
             )
             if has_extra_info:
                 table_info += "\n\n/*"
@@ -303,6 +307,10 @@ class Database:
 
     def get_database_list(self):
         session = self._db_sessions()
-        cursor = session.execute(text(' show databases;'))
+        cursor = session.execute(text(" show databases;"))
         results = cursor.fetchall()
-        return [d[0] for d in results if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]]
+        return [
+            d[0]
+            for d in results
+            if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
+        ]

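Note: a usage sketch of the wrapper, not part of the commit, restricted to methods visible in this diff. The URI mirrors the one built in Config below; its credentials are placeholders.

```python
from pilot.common.sql_database import Database

# Placeholder credentials; engine_args mirror the Config hunk below.
db = Database.from_uri(
    "mysql+pymysql://user:password@127.0.0.1:3306",
    engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)

# System schemas are filtered out, as in get_database_list above.
print(db.get_database_list())
```
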
@@ -47,7 +47,6 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
         self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
 
-
         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
         self.image_provider = os.getenv("IMAGE_PROVIDER", True)
@@ -104,8 +103,17 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
 
         ### TODO Adapt to multiple types of libraries
-        self.local_db = Database.from_uri("mysql+pymysql://" + self.LOCAL_DB_USER +":"+ self.LOCAL_DB_PASSWORD +"@" +self.LOCAL_DB_HOST + ":" + str(self.LOCAL_DB_PORT) ,
-                                          engine_args ={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        self.local_db = Database.from_uri(
+            "mysql+pymysql://"
+            + self.LOCAL_DB_USER
+            + ":"
+            + self.LOCAL_DB_PASSWORD
+            + "@"
+            + self.LOCAL_DB_HOST
+            + ":"
+            + str(self.LOCAL_DB_PORT),
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )
 
         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")

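Note: a sketch of how these settings are supplied, not part of the commit. The defaults shown appear in the hunks above; Config(metaclass=Singleton) comes from the hunk header.

```python
import os

# Placeholder values; Config reads them via os.getenv at construction time.
os.environ.setdefault("LOCAL_DB_PASSWORD", "aa123456")
os.environ.setdefault("LLM_MODEL", "vicuna-13b")

from pilot.configs.config import Config

cfg = Config()        # Singleton metaclass: every call yields the same instance
assert cfg is Config()
print(cfg.LLM_MODEL)  # -> "vicuna-13b"
```
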
@@ -3,6 +3,7 @@
 
 import pymysql
 
+
 class MySQLOperator:
     """Connect MySQL Database fetch MetaData For LLM Prompt
     Args:
@@ -31,8 +32,7 @@ class MySQLOperator:
             results = cursor.fetchall()
             return results
 
-
-    def run_sql(self, db_name:str, sql:str, fetch: str = "all"):
+    def run_sql(self, db_name: str, sql: str, fetch: str = "all"):
         with self.conn.cursor() as cursor:
             cursor.execute("USE " + db_name)
             cursor.execute(sql)

@@ -44,6 +44,7 @@ class Conversation:
     skip_next: bool = False
     conv_id: Any = None
     last_user_input: Any = None
+
     def get_prompt(self):
         if self.sep_style == SeparatorStyle.SINGLE:
             ret = self.system + self.sep

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from pydantic import BaseModel, Field, root_validator, validator,Extra
+from pydantic import BaseModel, Field, root_validator, validator, Extra
 from abc import ABC, abstractmethod
 from typing import (
     Any,
@@ -17,13 +17,9 @@ from typing import (
 from pilot.scene.message import OnceConversation
 
 
-
-
-
 class BaseChatHistoryMemory(ABC):
-
     def __init__(self):
-        self.conversations:List[OnceConversation] = []
+        self.conversations: List[OnceConversation] = []
 
     @abstractmethod
     def messages(self) -> List[OnceConversation]:  # type: ignore
@@ -33,8 +29,6 @@ class BaseChatHistoryMemory(ABC):
     def append(self, message: OnceConversation) -> None:
         """Append the message to the record in the local file"""
 
-
-
     @abstractmethod
     def clear(self) -> None:
         """Clear session memory from the local file"""

@@ -5,31 +5,33 @@ import datetime
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pathlib import Path
 
-from pilot.configs.config import Config
-from pilot.scene.message import OnceConversation, conversation_from_dict,conversations_to_dict
+from pilot.configs.config import Config
+from pilot.scene.message import (
+    OnceConversation,
+    conversation_from_dict,
+    conversations_to_dict,
+)
 
 
 CFG = Config()
 
 
 class FileHistoryMemory(BaseChatHistoryMemory):
-    def __init__(self, chat_session_id:str):
+    def __init__(self, chat_session_id: str):
         now = datetime.datetime.now()
         date_string = now.strftime("%Y%m%d")
         path: str = f"{CFG.message_dir}/{date_string}"
         os.makedirs(path, exist_ok=True)
 
         dir_path = Path(path)
-        self.file_path = Path(dir_path / f"{chat_session_id}.json")
+        self.file_path = Path(dir_path / f"{chat_session_id}.json")
         if not self.file_path.exists():
             self.file_path.touch()
             self.file_path.write_text(json.dumps([]))
 
-
-
     def messages(self) -> List[OnceConversation]:
         items = json.loads(self.file_path.read_text())
-        history:List[OnceConversation] = []
+        history: List[OnceConversation] = []
         for onece in items:
             messages = conversation_from_dict(onece)
             history.append(messages)
@@ -38,8 +40,10 @@ class FileHistoryMemory(BaseChatHistoryMemory):
     def append(self, once_message: OnceConversation) -> None:
         historys = self.messages()
         historys.append(once_message)
-        self.file_path.write_text(json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4), encoding="UTF-8")
+        self.file_path.write_text(
+            json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4),
+            encoding="UTF-8",
+        )
 
     def clear(self) -> None:
         self.file_path.write_text(json.dumps([]))

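Note: a round-trip sketch of the file-backed history, not part of the commit. OnceConversation and its add_* methods appear in pilot/scene/message.py later in this diff; CFG.message_dir must point somewhere writable.

```python
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.scene.message import OnceConversation

# Creates {CFG.message_dir}/{YYYYMMDD}/demo-session.json on first use.
memory = FileHistoryMemory("demo-session")

once = OnceConversation()
once.add_user_message("hello")
once.add_ai_message("hi there")

memory.append(once)            # JSON with ensure_ascii=False, indent=4, UTF-8
print(len(memory.messages()))  # -> 1
memory.clear()                 # resets the file to []
```
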
@@ -41,16 +41,16 @@ class BaseOutputParser(ABC):
         text = text.lower()
         respObj = json.loads(text)
 
-        xx = respObj['response']
-        xx = xx.strip(b'\x00'.decode())
+        xx = respObj["response"]
+        xx = xx.strip(b"\x00".decode())
         respObj_ex = json.loads(xx)
-        if respObj_ex['error_code'] == 0:
-            all_text = respObj_ex['text']
+        if respObj_ex["error_code"] == 0:
+            all_text = respObj_ex["text"]
             ### Parse the returned text and extract the AI reply
             tmpResp = all_text.split(sep)
             last_index = -1
             for i in range(len(tmpResp)):
-                if tmpResp[i].find('assistant:') != -1:
+                if tmpResp[i].find("assistant:") != -1:
                     last_index = i
             ai_response = tmpResp[last_index]
             ai_response = ai_response.replace("assistant:", "")
@@ -60,9 +60,7 @@ class BaseOutputParser(ABC):
             print("un_stream clear response:{}", ai_response)
             return ai_response
         else:
-            raise ValueError("Model server error!code=" + respObj_ex['error_code']);
-
-
+            raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
     def parse_model_server_out(self, response) -> str:
         """

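Note: a standalone sketch of what the parser above undoes, not part of the commit. The model server wraps a JSON reply, itself JSON-encoded and NUL-padded, in a "response" field; the values below are made up.

```python
import json

# Minimal reproduction of the double-encoded payload shape handled above.
inner = json.dumps({"error_code": 0, "text": "...###assistant: select 1"})
body = json.dumps({"response": inner + "\x00\x00"})

resp_obj = json.loads(body.lower())             # the parser lowercases first
inner_text = resp_obj["response"].strip("\x00")  # strip NUL padding
resp_obj_ex = json.loads(inner_text)
assert resp_obj_ex["error_code"] == 0
# The AI part is the last "assistant:"-tagged segment after splitting on sep.
print(resp_obj_ex["text"].split("###")[-1].replace("assistant:", "").strip())
```
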
@@ -1,5 +1,3 @@
-
-
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -8,7 +6,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
 import yaml
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from pilot.scene.base_message import BaseMessage,HumanMessage,AIMessage, SystemMessage
+from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
 
 
 def get_buffer_string(
@@ -29,7 +27,6 @@ def get_buffer_string(
     return "\n".join(string_messages)
 
 
-
 class PromptValue(BaseModel, ABC):
     @abstractmethod
     def to_string(self) -> str:
@@ -39,6 +36,7 @@ class PromptValue(BaseModel, ABC):
     def to_messages(self) -> List[BaseMessage]:
         """Return prompt as messages."""
 
+
 class ChatPromptValue(PromptValue):
     messages: List[BaseMessage]
 
@@ -48,4 +46,4 @@ class ChatPromptValue(PromptValue):
 
     def to_messages(self) -> List[BaseMessage]:
         """Return prompt as messages."""
-        return self.messages
+        return self.messages

@@ -3,8 +3,8 @@ from typing import Any, Callable, Dict, List, Optional
 
 class PromptGenerator:
     """
-    generating custom prompt strings based on constraints;
-    Compatible with AutoGpt Plugin;
+    generating custom prompt strings based on constraints;
+    Compatible with AutoGpt Plugin;
     """
 
     def __init__(self) -> None:
@@ -22,8 +22,6 @@ class PromptGenerator:
         self.role = "AI"
         self.response_format = None
 
-
-
     def add_command(
         self,
         command_label: str,
@@ -51,4 +49,4 @@ class PromptGenerator:
             "args": command_args,
             "function": function,
         }
-        self.commands.append(command)
+        self.commands.append(command)

@@ -8,6 +8,7 @@ from pilot.common.formatting import formatter
 from pilot.out_parser.base import BaseOutputParser
 from pilot.common.schema import SeparatorStyle
 
+
 def jinja2_formatter(template: str, **kwargs: Any) -> str:
     """Format a template using jinja2."""
     try:
@@ -32,22 +33,23 @@ class PromptTemplate(BaseModel, ABC):
     """A list of the names of the variables the prompt template expects."""
     template_scene: str
 
-    template_define:str
+    template_define: str
     """this template define"""
     template: str
     """The prompt template."""
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
-    response_format:str
+    response_format: str
     """default use stream out"""
     stream_out: bool = True
     """"""
     output_parser: BaseOutputParser = None
     """"""
-    sep:str = SeparatorStyle.SINGLE.value
+    sep: str = SeparatorStyle.SINGLE.value
 
     class Config:
         """Configuration for this pydantic object."""
+
         arbitrary_types_allowed = True
 
     @property
@@ -96,10 +98,8 @@ class PromptTemplate(BaseModel, ABC):
         else:
             return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
 
-
-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt with the inputs.
-        """
+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs."""
 
         kwargs["response"] = json.dumps(self.response_format, indent=4)
         return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

@@ -207,6 +207,7 @@ class BasePromptTemplate(BaseModel, ABC):
         else:
             raise ValueError(f"{save_path} must be json or yaml")
 
+
 class StringPromptValue(PromptValue):
     text: str
 
@@ -219,7 +220,6 @@ class StringPromptValue(PromptValue):
         return [HumanMessage(content=self.text)]
 
 
-
 class StringPromptTemplate(BasePromptTemplate, ABC):
     """String prompt should expose the format method, returning a prompt."""
 
@@ -360,4 +360,4 @@ class PromptTemplate(StringPromptTemplate):
 
 
 # For backwards compatibility.
-Prompt = PromptTemplate
+Prompt = PromptTemplate

@@ -1,8 +1,9 @@
 from enum import Enum
 
+
 class ChatScene(Enum):
     ChatWithDb = "chat_with_db"
     ChatExecution = "chat_execution"
     ChatKnowledge = "chat_default_knowledge"
     ChatNewKnowledge = "chat_new_knowledge"
-    ChatNormal = "chat_normal"
+    ChatNormal = "chat_normal"

@@ -20,7 +20,7 @@ from pilot.prompts.prompt_new import PromptTemplate
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pilot.memory.chat_history.file_history import FileHistoryMemory
 
-from pilot.configs.model_config import LOGDIR, DATASETS_DIR
+from pilot.configs.model_config import LOGDIR, DATASETS_DIR
 from pilot.utils import (
     build_logger,
     server_error_msg,
@@ -32,8 +32,10 @@ from pilot.configs.config import Config
 logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
 headers = {"User-Agent": "dbgpt Client"}
 CFG = Config()
-class BaseChat( ABC):
-    chat_scene:str = None
+
+
+class BaseChat(ABC):
+    chat_scene: str = None
     llm_model: Any = None
     temperature: float = 0.6
     max_new_tokens: int = 1024
@@ -42,17 +44,20 @@ class BaseChat( ABC):
 
     class Config:
         """Configuration for this pydantic object."""
+
         arbitrary_types_allowed = True
 
     def __init__(self, chat_mode, chat_session_id, current_user_input):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
-        self.current_user_input:str = current_user_input
+        self.current_user_input: str = current_user_input
         self.llm_model = CFG.LLM_MODEL
         ### TODO
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
-        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[
+            self.chat_mode.value
+        ]
         self.history_message: List[OnceConversation] = []
         self.current_message: OnceConversation = OnceConversation()
         self.current_tokens_used: int = 0
@@ -69,7 +74,6 @@ class BaseChat( ABC):
     def chat_type(self) -> str:
         raise NotImplementedError("Not supported for this chat type.")
 
-
     def call(self):
         pass
 
@@ -88,7 +92,7 @@ class BaseChat( ABC):
         """
         return self.memory.messages()
 
-    def generate(self, p)->str:
+    def generate(self, p) -> str:
         """
         generate context for LLM input
         Args:

@@ -15,6 +15,7 @@ from typing import (
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
+
 class PromptValue(BaseModel, ABC):
     @abstractmethod
     def to_string(self) -> str:
@@ -37,7 +38,6 @@ class BaseMessage(BaseModel):
     """Type of the message, used for serialization."""
 
 
-
 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
 
@@ -49,7 +49,6 @@ class HumanMessage(BaseMessage):
         return "human"
 
 
-
 class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
 
@@ -81,8 +80,6 @@ class SystemMessage(BaseMessage):
         return "system"
 
 
-
-
 class Generation(BaseModel):
     """Output of a single generation."""
 
@@ -94,7 +91,6 @@ class Generation(BaseModel):
     """May include things like reason for finishing (e.g. in OpenAI)"""
 
 
-
 class ChatGeneration(Generation):
     """Output of a single generation."""
 
@@ -126,7 +122,6 @@ class LLMResult(BaseModel):
     """For arbitrary LLM provider specific output."""
 
 
-
 def _message_to_dict(message: BaseMessage) -> dict:
     return {"type": message.type, "data": message.dict()}
 
@@ -149,6 +144,5 @@ def _message_from_dict(message: dict) -> BaseMessage:
         raise ValueError(f"Got unexpected type: {_type}")
 
 
-
 def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]

@@ -14,7 +14,13 @@ from sqlalchemy import (
 )
 from typing import Any, Iterable, List, Optional
 
-from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage, ViewMessage
+from pilot.scene.base_message import (
+    BaseMessage,
+    SystemMessage,
+    HumanMessage,
+    AIMessage,
+    ViewMessage,
+)
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -25,7 +31,11 @@ from pilot.utils import (
     build_logger,
     server_error_msg,
 )
-from pilot.common.markdown_text import generate_markdown_table, generate_htm_table, datas_to_table_html
+from pilot.common.markdown_text import (
+    generate_markdown_table,
+    generate_htm_table,
+    datas_to_table_html,
+)
 from pilot.scene.chat_db.prompt import chat_db_prompt
 from pilot.out_parser.base import BaseOutputParser
 from pilot.scene.chat_db.out_parser import DbChatOutputParser
@@ -39,8 +49,7 @@ class ChatWithDb(BaseChat):
     """Number of results to return from the query"""
 
     def __init__(self, chat_session_id, db_name, user_input):
-        """
-        """
+        """ """
         super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
         if not db_name:
             raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
@@ -71,7 +80,9 @@ class ChatWithDb(BaseChat):
         ### Build the current conversation: construct the prompt as in the first round? Consider database switching
         if self.history_message:
             ## TODO decide how to handle a database switch when history is present
-            logger.info(f"There are already {len(self.history_message)} rounds of conversations!")
+            logger.info(
+                f"There are already {len(self.history_message)} rounds of conversations!"
+            )
 
         self.current_message.add_system_message(current_prompt)
 
@@ -87,32 +98,56 @@ class ChatWithDb(BaseChat):
         try:
             ### Call the non-streaming model service endpoint
 
-            response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120)
-            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+            response = requests.post(
+                urljoin(CFG.MODEL_SERVER, "generate"),
+                headers=headers,
+                json=payload,
+                timeout=120,
+            )
+            ai_response_text = (
+                self.prompt_template.output_parser.parse_model_server_out(response)
+            )
             self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+            prompt_define_response = (
+                self.prompt_template.output_parser.parse_prompt_response(
+                    ai_response_text
+                )
+            )
 
             result = self.database.run(self.db_connect, prompt_define_response.sql)
 
-            if hasattr(prompt_define_response, 'thoughts'):
+            if hasattr(prompt_define_response, "thoughts"):
                 if prompt_define_response.thoughts.get("speak"):
                     self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("speak"),result))
+                        self.prompt_template.output_parser.parse_view_response(
+                            prompt_define_response.thoughts.get("speak"), result
+                        )
+                    )
                 elif prompt_define_response.thoughts.get("reasoning"):
                     self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("reasoning"), result))
+                        self.prompt_template.output_parser.parse_view_response(
+                            prompt_define_response.thoughts.get("reasoning"), result
+                        )
+                    )
                 else:
                     self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts, result))
+                        self.prompt_template.output_parser.parse_view_response(
+                            prompt_define_response.thoughts, result
+                        )
+                    )
             else:
                 self.current_message.add_view_message(
-                    self.prompt_template.output_parser.parse_view_response(prompt_define_response, result))
+                    self.prompt_template.output_parser.parse_view_response(
+                        prompt_define_response, result
+                    )
+                )
 
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild!" + str(e))
-            self.current_message.add_view_message(f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """)
+            self.current_message.add_view_message(
+                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
+            )
         ### Persist the conversation record
         self.memory.append(self.current_message)
 
@@ -121,10 +156,10 @@ class ChatWithDb(BaseChat):
         # A single round can contain only one user record and one AI record
         # TODO display the reasoning process in the frontend...
         for message in self.current_message.messages:
-            if (isinstance(message, HumanMessage)):
+            if isinstance(message, HumanMessage):
                 ret[-1][-2] = message.content
             # whether to display the reasoning process
-            if (isinstance(message, ViewMessage)):
+            if isinstance(message, ViewMessage):
                 ret[-1][-1] = message.content
 
         return ret
@@ -132,34 +167,51 @@ class ChatWithDb(BaseChat):
     # Temporary, for frontend compatibility
     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
-            if message.type == 'view':
+            if message.type == "view":
                 return message.content
         return None
 
     def generate_llm_text(self) -> str:
-        text = self.prompt_template.template_define + self.prompt_template.sep
+        text = self.prompt_template.template_define + self.prompt_template.sep
         ### Process the history first
-        if (len(self.history_message) > self.chat_retention_rounds):
+        if len(self.history_message) > self.chat_retention_rounds:
            ### Merge the first and last rounds of history into the context; messages rendered for the user must be filtered out
             for first_message in self.history_message[0].messages:
                 if not isinstance(first_message, ViewMessage):
-                    text += first_message.type + ":" + first_message.content + self.prompt_template.sep
+                    text += (
+                        first_message.type
+                        + ":"
+                        + first_message.content
+                        + self.prompt_template.sep
+                    )
 
             index = self.chat_retention_rounds - 1
             for last_message in self.history_message[-index:].messages:
                 if not isinstance(last_message, ViewMessage):
-                    text += last_message.type + ":" + last_message.content + self.prompt_template.sep
+                    text += (
+                        last_message.type
+                        + ":"
+                        + last_message.content
+                        + self.prompt_template.sep
+                    )
 
         else:
             ### Concatenate the history directly
             for conversation in self.history_message:
                 for message in conversation.messages:
                     if not isinstance(message, ViewMessage):
-                        text += message.type + ":" + message.content + self.prompt_template.sep
+                        text += (
+                            message.type
+                            + ":"
+                            + message.content
+                            + self.prompt_template.sep
+                        )
 
         ### current conversation
         for now_message in self.current_message.messages:
-            text += now_message.type + ":" + now_message.content + self.prompt_template.sep
+            text += (
+                now_message.type + ":" + now_message.content + self.prompt_template.sep
+            )
 
         return text

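Note: the assembly rule in generate_llm_text above, distilled into a standalone sketch (not part of the commit). Every non-view message is rendered as type:content and joined with the template separator.

```python
# Distilled from generate_llm_text; sep is SeparatorStyle.SINGLE.value.
sep = "###"
messages = [
    ("system", "You are a SQL assistant."),
    ("human", "How many users signed up today?"),
    ("view", "<table>...</table>"),  # ViewMessage: frontend-only, filtered out
]

text = ""
for mtype, content in messages:
    if mtype != "view":
        text += mtype + ":" + content + sep
print(text)
# -> system:You are a SQL assistant.###human:How many users signed up today?###
```
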
@@ -1,25 +1,24 @@
 import json
 import re
 from abc import ABC, abstractmethod
-from typing import (
-    Dict,
-    NamedTuple
-)
+from typing import Dict, NamedTuple
 import pandas as pd
 from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
 from pilot.configs.model_config import LOGDIR
 
 
 class SqlAction(NamedTuple):
     sql: str
     thoughts: Dict
 
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
 
 class DbChatOutputParser(BaseOutputParser):
-
-    def __init__(self, sep:str, is_stream_out: bool):
-        super().__init__(sep=sep, is_stream_out=is_stream_out )
+    def __init__(self, sep: str, is_stream_out: bool):
+        super().__init__(sep=sep, is_stream_out=is_stream_out)
 
     def parse_model_server_out(self, response) -> str:
         return super().parse_model_server_out(response)
@@ -31,20 +30,20 @@ class DbChatOutputParser(BaseOutputParser):
         if "```" in cleaned_output:
             cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
+            cleaned_output = cleaned_output[len("```json") :]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
+            cleaned_output = cleaned_output[len("```") :]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
         if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
             logger.info("illegal json processing")
-            json_pattern = r'{(.+?)}'
+            json_pattern = r"{(.+?)}"
             m = re.search(json_pattern, cleaned_output)
             if m:
                 cleaned_output = m.group(0)
             else:
-                raise ValueError("model server out not fllow the prompt!")
+                raise ValueError("model server out not fllow the prompt!")
 
         response = json.loads(cleaned_output)
         sql, thoughts = response["sql"], response["thoughts"]

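Note: a standalone sketch of the fence-stripping steps above, not part of the commit; it follows the same order of operations on a made-up model reply.

```python
import json
import re

FENCE = "`" * 3  # a literal triple-backtick fence


def clean_model_output(text: str) -> dict:
    """Mirror of the cleanup steps in the parser above (sketch only)."""
    cleaned = text.strip()
    if cleaned.startswith(FENCE + "json"):
        cleaned = cleaned[len(FENCE + "json") :]
    if cleaned.startswith(FENCE):
        cleaned = cleaned[len(FENCE) :]
    if cleaned.endswith(FENCE):
        cleaned = cleaned[: -len(FENCE)]
    cleaned = cleaned.strip()
    if not cleaned.startswith("{") or not cleaned.endswith("}"):
        # Fall back to the first {...} span, as the parser does.
        m = re.search(r"{(.+?)}", cleaned, re.DOTALL)
        if not m:
            raise ValueError("model output does not follow the prompt")
        cleaned = m.group(0)
    return json.loads(cleaned)


reply = FENCE + 'json\n{"sql": "SELECT 1", "thoughts": {"speak": "ok"}}\n' + FENCE
print(clean_model_output(reply))  # -> {'sql': 'SELECT 1', 'thoughts': {'speak': 'ok'}}
```
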
@@ -45,7 +45,7 @@ RESPONSE_FORMAT = {
         "reasoning": "reasoning",
         "speak": "thoughts summary to say to user",
     },
-    "sql": "SQL Query to run"
+    "sql": "SQL Query to run",
 }
 
 PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -59,7 +59,9 @@ chat_db_prompt = PromptTemplate(
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
-    output_parser=DbChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
+    output_parser=DbChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
 )
 
 CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})

@@ -4,8 +4,10 @@ from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.message import OnceConversation
 from pilot.scene.base import ChatScene
 
+
 class ChatWithPlugin(BaseChat):
-    chat_scene: str= ChatScene.ChatExecution.value
+    chat_scene: str = ChatScene.ChatExecution.value
 
     def __init__(self, chat_mode, chat_session_id, current_user_input):
         super().__init__(chat_mode, chat_session_id, current_user_input)
@@ -23,4 +25,4 @@ class ChatWithPlugin(BaseChat):
 
     @property
     def chat_type(self) -> str:
-        return ChatScene.ChatExecution.value
+        return ChatScene.ChatExecution.value

@@ -1,19 +1,17 @@
-
 from pilot.scene.base_chat import BaseChat
 from pilot.singleton import Singleton
 from pilot.scene.chat_db.chat import ChatWithDb
 from pilot.scene.chat_execution.chat import ChatWithPlugin
 
-class ChatFactory(metaclass=Singleton):
 
+class ChatFactory(metaclass=Singleton):
     @staticmethod
     def get_implementation(chat_mode, **kwargs):
-
         chat_classes = BaseChat.__subclasses__()
         implementation = None
         for cls in chat_classes:
-            if(cls.chat_scene == chat_mode):
+            if cls.chat_scene == chat_mode:
                 implementation = cls(**kwargs)
-        if(implementation == None):
-            raise Exception('Invalid implementation name:' + chat_mode)
-        return implementation
+        if implementation == None:
+            raise Exception("Invalid implementation name:" + chat_mode)
+        return implementation

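Note: how the factory resolves a scene, as a sketch that is not part of the commit. The module path pilot.scene.chat_factory is assumed; the keyword arguments mirror the webserver hunk below.

```python
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory  # assumed module path

factory = ChatFactory()  # Singleton metaclass: one shared instance

# get_implementation scans BaseChat.__subclasses__() for a matching chat_scene.
chat = factory.get_implementation(
    ChatScene.ChatWithDb.value,
    chat_session_id="demo-session",
    db_name="mydb",
    user_input="How many users signed up last week?",
)
chat.call()
```
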
@@ -9,12 +9,20 @@ from typing import (
     List,
 )
 
-from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, ViewMessage, messages_to_dict, messages_from_dict
+from pilot.scene.base_message import (
+    BaseMessage,
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    ViewMessage,
+    messages_to_dict,
+    messages_from_dict,
+)
 
 
 class OnceConversation:
     """
-    All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
+    All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
     """
 
     def __init__(self):
@@ -26,7 +34,9 @@ class OnceConversation:
 
     def add_user_message(self, message: str) -> None:
         """Add a user message to the store"""
-        has_message = any(isinstance(instance, HumanMessage) for instance in self.messages)
+        has_message = any(
+            isinstance(instance, HumanMessage) for instance in self.messages
+        )
         if has_message:
             raise ValueError("Already Have Human message")
         self.messages.append(HumanMessage(content=message))
@@ -38,6 +48,7 @@ class OnceConversation:
             raise ValueError("Already Have Ai message")
         self.messages.append(AIMessage(content=message))
         """ """
+
     def add_view_message(self, message: str) -> None:
         """Add an AI message to the store"""
 
@@ -50,7 +61,7 @@ class OnceConversation:
 
     def set_start_time(self, datatime: datetime):
         dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
-        self.start_date = dt_str;
+        self.start_date = dt_str
 
     def clear(self) -> None:
         """Remove all messages from the store"""
@@ -71,7 +82,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
         "start_date": start_str,
         "cost": once.cost if once.cost else 0,
         "tokens": once.tokens if once.tokens else 0,
-        "messages": messages_to_dict(once.messages)
+        "messages": messages_to_dict(once.messages),
     }
 
 
@@ -81,10 +92,10 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
 
 def conversation_from_dict(once: dict) -> OnceConversation:
     conversation = OnceConversation()
-    conversation.cost = once.get('cost', 0)
-    conversation.tokens = once.get('tokens', 0)
-    conversation.start_date = once.get('start_date', '')
-    conversation.chat_order = int(once.get('chat_order'))
-    print(once.get('messages'))
-    conversation.messages = messages_from_dict(once.get('messages', []))
+    conversation.cost = once.get("cost", 0)
+    conversation.tokens = once.get("tokens", 0)
+    conversation.start_date = once.get("start_date", "")
+    conversation.chat_order = int(once.get("chat_order"))
+    print(once.get("messages"))
+    conversation.messages = messages_from_dict(once.get("messages", []))
    return conversation

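Note: a serialization round trip over the helpers above, not part of the commit. It assumes chat_order is populated by the parts of _conversation_to_dic outside this hunk, since conversation_from_dict reads it back.

```python
from pilot.scene.message import (
    OnceConversation,
    conversations_to_dict,
    conversation_from_dict,
)

once = OnceConversation()
once.add_user_message("hello")  # at most one HumanMessage per round
once.add_ai_message("hi")       # at most one AIMessage per round

payload = conversations_to_dict([once])        # dicts with start_date/cost/tokens/messages
restored = conversation_from_dict(payload[0])  # messages rebuilt via messages_from_dict
print([m.type for m in restored.messages])     # -> ['human', 'ai']
```
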
@@ -137,13 +137,15 @@ def load_demo(url_params, request: gr.Request):
     unique_id = uuid.uuid1()
     state.conv_id = str(unique_id)
 
-    return (state,
-            dropdown_update,
-            gr.Chatbot.update(visible=True),
-            gr.Textbox.update(visible=True),
-            gr.Button.update(visible=True),
-            gr.Row.update(visible=True),
-            gr.Accordion.update(visible=True))
+    return (
+        state,
+        dropdown_update,
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=True),
+    )
 
 
 def get_conv_log_filename():

@@ -203,30 +205,31 @@ def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
     elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
         return ChatScene.ChatExecution
     else:
-        return ChatScene.ChatNormal
+        return ChatScene.ChatNormal
 
 
-def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(
+    state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+):
     logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
     start_tstamp = time.time()
-    scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
+    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
     print(f"Current chat mode: {scene.value}")
     model_name = CFG.LLM_MODEL
 
     if ChatScene.ChatWithDb == scene:
         logger.info("DB-based chat uses the new mode!")
-        chat_param ={
+        chat_param = {
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
-            "user_input": state.last_user_input
+            "user_input": state.last_user_input,
         }
-        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
         chat.call()
-        state.messages[-1][-1] = f"{chat.current_ai_response()}"
+        state.messages[-1][-1] = f"{chat.current_ai_response()}"
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
     else:
         dbname = db_selector
         # TODO the request here should incorporate the existing knowledge base so answers are grounded in it; the prompt still needs further optimization
         if state.skip_next:

@@ -242,7 +245,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         # Add context hints to the prompt so it answers from existing knowledge. Should the hint go only in the first round, or in every round?
         # If the user's questions vary widely, add the hint in every round.
         if db_selector:
-            new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+            new_state.append_message(
+                new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+            )
             new_state.append_message(new_state.roles[1], None)
         else:
             new_state.append_message(new_state.roles[0], query)
@@ -251,7 +256,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         new_state.conv_id = uuid.uuid4().hex
         state = new_state
 
-
     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:

@@ -263,16 +267,25 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
     if mode == conversation_types["custome"] and not db_selector:
-        persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
+        persist_dir = os.path.join(
+            KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
+        )
         print("Vector store persistence path: ", persist_dir)
-        knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"],
-                                                                                                                                                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["sentence-transforms"],
+            local_persist=False,
+            vector_store_config={
+                "vector_store_name": vector_store_name["vs_name"],
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            },
+        )
         query = state.messages[-2][1]
         docs = knowledge_embedding_client.similar_search(query, 1)
         context = [d.page_content for d in docs]
         prompt_template = PromptTemplate(
             template=conv_qa_prompt_template,
-            input_variables=["context", "question"]
+            input_variables=["context", "question"],
         )
         result = prompt_template.format(context="\n".join(context), question=query)
         state.messages[-2][1] = result

@@ -285,7 +298,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         "prompt": prompt,
         "temperature": float(temperature),
         "max_new_tokens": int(max_new_tokens),
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
+        "stop": state.sep
+        if state.sep_style == SeparatorStyle.SINGLE
+        else state.sep2,
     }
     logger.info(f"Requert: \n{payload}")
 
@@ -295,8 +310,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
 
     try:
         # Stream output
-        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                                 headers=headers, json=payload, stream=True, timeout=20)
+        response = requests.post(
+            urljoin(CFG.MODEL_SERVER, "generate_stream"),
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=20,
+        )
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())

@@ -309,12 +329,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
                 output = data["text"] + f" (error_code: {data['error_code']})"
                 state.messages[-1][-1] = output
                 yield (state, state.to_gradio_chatbot()) + (
-                    disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+                    disable_btn,
+                    disable_btn,
+                    disable_btn,
+                    enable_btn,
+                    enable_btn,
+                )
                 return
 
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        yield (state, state.to_gradio_chatbot()) + (
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
         return
 
     state.messages[-1][-1] = state.messages[-1][-1][:-1]

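Note: a client-side sketch of the streaming protocol used above, not part of the commit. The endpoint URL is a placeholder; the payload fields mirror the hunk.

```python
import json

import requests

# Placeholder server; payload fields mirror the webserver code above.
payload = {
    "prompt": "hello",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "stop": "###",
}
response = requests.post(
    "http://127.0.0.1:8000/generate_stream",
    json=payload,
    stream=True,
    timeout=20,
)
# Chunks are b"\0"-delimited JSON objects carrying "text" and "error_code".
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"], end="\r")
```
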
@@ -573,7 +604,6 @@ def knowledge_embedding_store(vs_id, files):
     )
     knowledge_embedding_client.knowledge_embedding()
 
-
     logger.info("knowledge embedding success")
     return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")