Repository: https://github.com/csunny/DB-GPT.git
Commit: 7e6cf9c9e0

Merge branch 'summary' into dev

# Conflicts:
#   pilot/model/llm_out/guanaco_llm.py
#   pilot/out_parser/base.py
#   pilot/scene/base_chat.py
@@ -55,8 +55,6 @@ def fix_and_parse_json(
        logger.error("参数解析错误", e)


-
-
 def correct_json(json_to_load: str) -> str:
     """
     Correct common JSON errors.
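Note: the hunk above only trims blank lines around correct_json, but its docstring names the technique. A hypothetical standalone sketch of what "correct common JSON errors" usually means for LLM output (not the repository's actual implementation):

import json
import re


def correct_json_sketch(json_to_load: str) -> str:
    """Fix two frequent LLM JSON mistakes: single quotes and trailing commas."""
    # Turn single-quoted tokens into double-quoted ones.
    fixed = re.sub(r"(?<![\w\"])'([^']*)'", r'"\1"', json_to_load)
    # Drop trailing commas before a closing brace or bracket.
    fixed = re.sub(r",\s*([}\]])", r"\1", fixed)
    return fixed


print(json.loads(correct_json_sketch("{'table': ['orders', 'products'],}")))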
@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from collections import deque

+
 class FixedSizeDict(OrderedDict):
     def __init__(self, max_size):
         super().__init__()
@@ -11,6 +12,7 @@ class FixedSizeDict(OrderedDict):
             self.popitem(last=False)
         super().__setitem__(key, value)

+
 class FixedSizeList:
     def __init__(self, max_size):
         self.max_size = max_size
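Note: the two hunks above only adjust blank lines, but together they show the whole FixedSizeDict idea: a bounded mapping that evicts its oldest entry. A self-contained version of the visible logic (assuming __init__ also stores max_size, which the diff truncates):

from collections import OrderedDict


class FixedSizeDict(OrderedDict):
    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key, value):
        # Once the bound is reached, drop the oldest (first-inserted) entry.
        if len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)


cache = FixedSizeDict(2)
cache["a"] = 1
cache["b"] = 2
cache["c"] = 3  # evicts "a"
print(dict(cache))  # {'b': 2, 'c': 3}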
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import sqlparse
 import regex as re
 import warnings
 from typing import Any, Iterable, List, Optional
 from pydantic import BaseModel, Field, root_validator, validator, Extra
@@ -283,16 +283,16 @@ class Database:
             return f"Error: {e}"

     def __write(self, session, write_sql):
         print(f"Write[{write_sql}]")
         db_cache = self.get_session_db(session)
         result = session.execute(text(write_sql))
         session.commit()
-        #TODO Subsequent optimization of dynamically specified database submission loss target problem
+        # TODO Subsequent optimization of dynamically specified database submission loss target problem
         session.execute(text(f"use `{db_cache}`"))
         print(f"SQL[{write_sql}], result:{result.rowcount}")
         return result.rowcount

-    def __query(self,session, query, fetch: str = "all"):
+    def __query(self, session, query, fetch: str = "all"):
         """
         only for query
         Args:
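Note on the TODO above: after session.commit() the code points the session back at the cached schema with a literal `use` statement. A minimal sketch of that pattern, pulled out for clarity (the pooling behaviour it works around is our reading of the TODO, not something the diff states):

from sqlalchemy import text


def write_and_restore(session, write_sql: str, db_cache: str) -> int:
    result = session.execute(text(write_sql))
    session.commit()
    # Re-select the session's database: with pooled/shared connections a
    # commit can otherwise leave later statements targeting the wrong schema.
    session.execute(text(f"use `{db_cache}`"))
    return result.rowcount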
@@ -390,37 +390,44 @@ class Database:
             cmd_type = parts[0]

             # Dispatch on the command type
-            if cmd_type == 'insert':
-                match = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower())
+            if cmd_type == "insert":
+                match = re.match(
+                    r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
+                )
                 if match:
                     table_name, columns, values = match.groups()
                     # Split the column list and value list into individual columns and values
-                    columns = columns.split(',')
-                    values = values.split(',')
+                    columns = columns.split(",")
+                    values = values.split(",")
                     # Build the WHERE clause
-                    where_clause = " AND ".join([f"{col.strip()}={val.strip()}" for col, val in zip(columns, values)])
-                    return f'SELECT * FROM {table_name} WHERE {where_clause}'
+                    where_clause = " AND ".join(
+                        [
+                            f"{col.strip()}={val.strip()}"
+                            for col, val in zip(columns, values)
+                        ]
+                    )
+                    return f"SELECT * FROM {table_name} WHERE {where_clause}"

-            elif cmd_type == 'delete':
+            elif cmd_type == "delete":
                 table_name = parts[2]  # delete from <table_name> ...
                 # Return a SELECT statement that selects all data from the table
-                return f'SELECT * FROM {table_name}'
+                return f"SELECT * FROM {table_name}"

-            elif cmd_type == 'update':
+            elif cmd_type == "update":
                 table_name = parts[1]
-                set_idx = parts.index('set')
-                where_idx = parts.index('where')
+                set_idx = parts.index("set")
+                where_idx = parts.index("where")
                 # Extract the field name from the `set` clause
-                set_clause = parts[set_idx + 1: where_idx][0].split('=')[0].strip()
+                set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
                 # Extract the condition after `where`
-                where_clause = ' '.join(parts[where_idx + 1:])
+                where_clause = " ".join(parts[where_idx + 1 :])
                 # Return a SELECT statement that selects the updated data
-                return f'SELECT {set_clause} FROM {table_name} WHERE {where_clause}'
+                return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
             else:
                 raise ValueError(f"Unsupported SQL command type: {cmd_type}")

     def __sql_parse(self, sql):
         sql = sql.strip()
         parsed = sqlparse.parse(sql)[0]
         sql_type = parsed.get_type()

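Note: the insert branch above rewrites a write statement into a SELECT that previews the affected rows. Pulled out as a runnable function (stdlib re instead of the regex package the file imports; behaviour is identical here):

import re


def insert_to_select(write_sql: str) -> str:
    match = re.match(
        r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
    )
    if not match:
        raise ValueError(f"Unsupported SQL command: {write_sql}")
    table_name, columns, values = match.groups()
    where_clause = " AND ".join(
        f"{col.strip()}={val.strip()}"
        for col, val in zip(columns.split(","), values.split(","))
    )
    return f"SELECT * FROM {table_name} WHERE {where_clause}"


print(insert_to_select("insert into users (id, name) values (1, 'bob')"))
# SELECT * FROM users WHERE id=1 AND name='bob'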
@@ -429,8 +436,6 @@ class Database:
         print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
         return parsed, ttype, sql_type

-
-
     def get_indexes(self, table_name):
         """Get table indexes about specified table."""
         session = self._db_sessions()
@@ -103,8 +103,12 @@ class Config(metaclass=Singleton):
         else:
             self.plugins_denylist = []
         ### Native SQL Execution Capability Control Configuration
-        self.NATIVE_SQL_CAN_RUN_DDL = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") =="True"
-        self.NATIVE_SQL_CAN_RUN_WRITE = os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") =="True"
+        self.NATIVE_SQL_CAN_RUN_DDL = (
+            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
+        )
+        self.NATIVE_SQL_CAN_RUN_WRITE = (
+            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
+        )

         ### Local database connection configuration
         self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
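Note: os.getenv always returns a string, so the == "True" comparison is what converts the flag to a bool; only the exact string "True" enables the capability:

import os

os.environ["NATIVE_SQL_CAN_RUN_DDL"] = "true"  # lowercase on purpose
print(os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True")  # False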
@@ -11,11 +11,9 @@ class BaseConnect(BaseModel, ABC):
     type
     driver: str

-
     def get_session(self, db_name: str):
         pass

-
     def get_table_names(self) -> Iterable[str]:
         pass

@@ -11,8 +11,7 @@ class MySQLConnect(RDBMSDatabase):
     Usage:
     """

-    type:str = "MySQL"
+    type: str = "MySQL"
     connect_url = "mysql+pymysql://"
-
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]

@@ -106,7 +106,7 @@ class Conversation:


 conv_default = Conversation(
-    system = None,
+    system=None,
     roles=("human", "ai"),
     messages=[],
     offset=0,
@@ -298,7 +298,6 @@ chat_mode_title = {
     "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
     "chat_use_plugin": get_lang_text("chat_use_plugin"),
     "knowledge_qa": get_lang_text("knowledge_qa"),
-
 }

 conversation_sql_mode = {
@@ -29,7 +29,6 @@ lang_dicts = {
         "url_input_label": "输入网页地址",
         "add_as_new_klg": "添加为新知识库",
         "add_file_to_klg": "向知识库中添加文件",
-
         "upload_file": "上传文件",
         "add_file": "添加文件",
         "upload_and_load_to_klg": "上传并加载到知识库",
@@ -9,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER

 @torch.inference_mode()
 def chatglm_generate_stream(
     model, tokenizer, params, device, context_len=2048, stream_interval=2
 ):
     """Generate text using chatglm model's chat api"""
     prompt = params["prompt"]
@@ -57,7 +57,7 @@ def chatglm_generate_stream(
     # i = 0

     for i, (response, new_hist) in enumerate(
         model.stream_chat(tokenizer, query, hist, **generate_kwargs)
     ):
         if echo:
             output = query + " " + response
@@ -83,9 +83,7 @@ class BaseOutputParser(ABC):
                     output = self.__post_process_code(output)
                     yield output
                 else:
-                    output = (
-                        data["text"] + f" (error_code: {data['error_code']})"
-                    )
+                    output = data["text"] + f" (error_code: {data['error_code']})"
                     yield output

     def parse_model_nostream_resp(self, response, sep: str):
@@ -133,10 +133,8 @@ class PluginPromptGenerator:
         else:
             return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))

-    def generate_commands_string(self)->str:
+    def generate_commands_string(self) -> str:
         return f"{self._generate_numbered_list(self.commands, item_type='command')}"

-
-
     def generate_prompt_string(self) -> str:
         """
@@ -33,13 +33,13 @@ class PromptTemplate(BaseModel, ABC):
     """A list of the names of the variables the prompt template expects."""
     template_scene: Optional[str]

     template_define: Optional[str]
     """this template define"""
     template: Optional[str]
     """The prompt template."""
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
     response_format: Optional[str]
     """default use stream out"""
     stream_out: bool = True
     """"""
@@ -62,7 +62,9 @@ class PromptTemplate(BaseModel, ABC):
         if self.template:
             if self.response_format:
                 kwargs["response"] = json.dumps(self.response_format, indent=4)
-            return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
+            return DEFAULT_FORMATTER_MAPPING[self.template_format](
+                self.template, **kwargs
+            )

     def add_goals(self, goal: str) -> None:
         self.goals.append(goal)
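Note: DEFAULT_FORMATTER_MAPPING is defined elsewhere in pilot; a plausible minimal shape for the "f-string" entry used here (an assumption mirroring the common LangChain-style convention, not the project's verified definition):

DEFAULT_FORMATTER_MAPPING = {
    "f-string": lambda template, **kwargs: template.format(**kwargs),
}

template = "Question: {input}\nReturn at most {top_k} rows."
print(DEFAULT_FORMATTER_MAPPING["f-string"](template, input="top customers", top_k=5))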
@@ -129,7 +129,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()

-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
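Note: skip_echo_len measures the prompt as it will be echoed back in the stream so the client can slice it off; the new "+ 11" looks like an empirical offset for extra characters in the echo (the diff does not say why 11). Minimal illustration of the slicing itself:

prompt = "Hello</s>assistant:"
skip_echo_len = len(prompt.replace("</s>", " "))
streamed = prompt.replace("</s>", " ") + "MODEL OUTPUT"
print(streamed[skip_echo_len:])  # MODEL OUTPUT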
@@ -141,7 +141,7 @@ class BaseChat(ABC):
                 stream=True,
                 timeout=120,
             )
-            return response;
+            return response

             # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)

@@ -21,15 +21,21 @@ class ChatWithDbAutoExecute(BaseChat):

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatWithDbExecute,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatWithDbExecute,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         if not db_name:
-            raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should chose db!")
+            raise ValueError(
+                f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
+            )
         self.db_name = db_name
         self.database = CFG.local_db
         # Prepare DB info (get the connection for the specified database)
@@ -40,9 +46,7 @@ class ChatWithDbAutoExecute(BaseChat):
         try:
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
-            raise ValueError(
-                "Could not import DBSummaryClient. "
-            )
+            raise ValueError("Could not import DBSummaryClient. ")
         input_values = {
             "input": self.current_user_input,
             "top_k": str(self.top_k),
@@ -20,9 +20,8 @@ class DbChatOutputParser(BaseOutputParser):
     def __init__(self, sep: str, is_stream_out: bool):
         super().__init__(sep=sep, is_stream_out=is_stream_out)

-
     def parse_prompt_response(self, model_out_text):
-        clean_str = super().parse_prompt_response(model_out_text);
+        clean_str = super().parse_prompt_response(model_out_text)
         print("clean prompt response:", clean_str)
         response = json.loads(clean_str)
         sql, thoughts = response["sql"], response["thoughts"]
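Note: parse_prompt_response above expects the cleaned model reply to be a JSON object with "sql" and "thoughts" keys (names taken from the diff). For example:

import json

clean_str = '{"sql": "SELECT * FROM users LIMIT 5", "thoughts": "show a small sample"}'
response = json.loads(clean_str)
sql, thoughts = response["sql"], response["thoughts"]
print(sql)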
@@ -62,4 +62,3 @@ prompt = PromptTemplate(
     ),
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
@@ -19,13 +19,17 @@ class ChatWithDbQA(BaseChat):

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatWithDbQA,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatWithDbQA,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.db_name = db_name
         if db_name:
             self.database = CFG.local_db
@@ -34,17 +38,16 @@ class ChatWithDbQA(BaseChat):
         self.top_k: int = 5

     def generate_input_values(self):
-
         table_info = ""
         dialect = "mysql"
         try:
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
-            raise ValueError(
-                "Could not import DBSummaryClient. "
-            )
+            raise ValueError("Could not import DBSummaryClient. ")
         if self.db_name:
-            table_info = DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
+            table_info = DBSummaryClient.get_similar_tables(
+                dbname=self.db_name, query=self.current_user_input, topk=self.top_k
+            )
             # table_info = self.database.table_simple_info(self.db_connect)
             dialect = self.database.dialect

|
|||||||
"input": self.current_user_input,
|
"input": self.current_user_input,
|
||||||
"top_k": str(self.top_k),
|
"top_k": str(self.top_k),
|
||||||
"dialect": dialect,
|
"dialect": dialect,
|
||||||
"table_info": table_info
|
"table_info": table_info,
|
||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
|
|||||||
|
|
||||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||||
|
|
||||||
class NormalChatOutputParser(BaseOutputParser):
|
|
||||||
|
|
||||||
|
class NormalChatOutputParser(BaseOutputParser):
|
||||||
def parse_prompt_response(self, model_out_text) -> T:
|
def parse_prompt_response(self, model_out_text) -> T:
|
||||||
return model_out_text
|
return model_out_text
|
||||||
|
|
||||||
|
@@ -27,7 +27,6 @@ Pay attention to use only the column names that you can see in the schema description
 """


-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -37,7 +36,7 @@ prompt = PromptTemplate(
     input_variables=["input", "table_info", "dialect", "top_k"],
     response_format=None,
     template_define=PROMPT_SCENE_DEFINE,
-    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX ,
+    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
     output_parser=NormalChatOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
@@ -45,4 +44,3 @@ prompt = PromptTemplate(
 )

 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
@@ -14,56 +14,72 @@ from pilot.scene.chat_execution.prompt import prompt

 CFG = Config()


 class ChatWithPlugin(BaseChat):
     chat_scene: str = ChatScene.ChatExecution.value
-    plugins_prompt_generator:PluginPromptGenerator
+    plugins_prompt_generator: PluginPromptGenerator
     select_plugin: str = None

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None):
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatExecution,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_session_id,
+        user_input,
+        plugin_selector: str = None,
+    ):
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatExecution,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.plugins_prompt_generator = PluginPromptGenerator()
         self.plugins_prompt_generator.command_registry = CFG.command_registry
         # Load the commands available in the plugins
         self.select_plugin = plugin_selector
         if self.select_plugin:
             for plugin in CFG.plugins:
-                if plugin._name == plugin_selector :
+                if plugin._name == plugin_selector:
                     if not plugin.can_handle_post_prompt():
                         continue
-                    self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
+                    self.plugins_prompt_generator = plugin.post_prompt(
+                        self.plugins_prompt_generator
+                    )

         else:
             for plugin in CFG.plugins:
                 if not plugin.can_handle_post_prompt():
                     continue
-                self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
+                self.plugins_prompt_generator = plugin.post_prompt(
+                    self.plugins_prompt_generator
+                )


     def generate_input_values(self):
         input_values = {
             "input": self.current_user_input,
-            "constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)),
-            "commands_infos": self.plugins_prompt_generator.generate_commands_string()
+            "constraints": self.__list_to_prompt_str(
+                list(self.plugins_prompt_generator.constraints)
+            ),
+            "commands_infos": self.plugins_prompt_generator.generate_commands_string(),
         }
         return input_values

     def do_with_prompt_response(self, prompt_response):
         ## plugin command run
-        return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args',{}), self.plugins_prompt_generator)
+        return execute_command(
+            str(prompt_response.command.get("name")),
+            prompt_response.command.get("args", {}),
+            self.plugins_prompt_generator,
+        )

     def chat_show(self):
         super().chat_show()


     def __list_to_prompt_str(self, list: List) -> str:
         if list:
-            separator = '\n'
+            separator = "\n"
             return separator.join(list)
         else:
             return ""
@@ -10,14 +10,13 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class PluginAction(NamedTuple):
     command: Dict
     thoughts: Dict


-
 class PluginChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         response = json.loads(super().parse_prompt_response(model_out_text))
         command, thoughts = response["command"], response["thoughts"]
@@ -25,7 +24,7 @@ class PluginChatOutputParser(BaseOutputParser):

     def parse_view_response(self, speak, data) -> str:
         ### tool out data to table view
-        print(f"parse_view_response:{speak},{str(data)}" )
+        print(f"parse_view_response:{speak},{str(data)}")
         view_text = f"##### {speak}" + "\n" + str(data)
         return view_text

@@ -53,7 +53,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False

 prompt = PromptTemplate(
     template_scene=ChatScene.ChatExecution.value,
     input_variables=["input", "constraints", "commands_infos", "response"],
     response_format=json.dumps(RESPONSE_FORMAT, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
@@ -11,6 +11,7 @@ from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
 from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
 from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary

+
 class ChatFactory(metaclass=Singleton):
     @staticmethod
     def get_implementation(chat_mode, **kwargs):
@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -24,18 +23,22 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()


-class ChatNewKnowledge (BaseChat):
+class ChatNewKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatNewKnowledge.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatNewKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatNewKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.knowledge_name = knowledge_name
         vector_store_config = {
             "vector_store_name": knowledge_name,
@@ -49,21 +52,17 @@ class ChatNewKnowledge (BaseChat):
             vector_store_config=vector_store_config,
         )

-
     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatNewKnowledge.value
@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

-class NormalChatOutputParser(BaseOutputParser):

+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

@@ -23,7 +23,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """


-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -42,5 +41,3 @@ prompt = PromptTemplate(


 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-
@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -24,43 +23,42 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()


-class ChatDefaultKnowledge (BaseChat):
+class ChatDefaultKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatKnowledge.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         vector_store_config = {
             "vector_store_name": "default",
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
             file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
             local_persist=False,
             vector_store_config=vector_store_config,
         )

     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value
@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

-class NormalChatOutputParser(BaseOutputParser):

+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

@@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """


-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -39,5 +38,3 @@ prompt = PromptTemplate(


 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-
@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
@@ -8,34 +7,42 @@ from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
 CFG = Config()


-class InnerChatDBSummary (BaseChat):
+class InnerChatDBSummary(BaseChat):
     chat_scene: str = ChatScene.InnerChatDBSummary.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, db_select, db_summary):
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_session_id,
+        user_input,
+        db_select,
+        db_summary,
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.InnerChatDBSummary,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
-        self.db_name = db_select
-        self.db_summary = db_summary
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.InnerChatDBSummary,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
+
+        self.db_input = db_select
+        self.db_summary = db_summary

     def generate_input_values(self):
         input_values = {
-            "db_input": self.db_name,
-            "db_profile_summary": self.db_summary
+            "db_input": self.db_input,
+            "db_profile_summary": self.db_summary,
         }
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.InnerChatDBSummary.value
@@ -10,13 +10,15 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

-class NormalChatOutputParser(BaseOutputParser):

-    def parse_prompt_response(self, model_out_text) -> T:
-        return model_out_text
+class NormalChatOutputParser(BaseOutputParser):
+    def parse_prompt_response(self, model_out_text):
+        clean_str = super().parse_prompt_response(model_out_text)
+        print("clean prompt response:", clean_str)
+        return clean_str

     def parse_view_response(self, ai_text, data) -> str:
-        return ai_text["table"]
+        return ai_text

     def get_format_instructions(self) -> str:
         pass
@@ -7,33 +7,30 @@ from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
 from pilot.common.schema import SeparatorStyle

-from pilot.scene.chat_knowledge.inner_db_summary.out_parser import NormalChatOutputParser
+from pilot.scene.chat_knowledge.inner_db_summary.out_parser import (
+    NormalChatOutputParser,
+)


 CFG = Config()

-PROMPT_SCENE_DEFINE =""""""
+PROMPT_SCENE_DEFINE = """"""

 _DEFAULT_TEMPLATE = """
 Based on the following known database information?, answer which tables are involved in the user input.
 Known database information:{db_profile_summary}
 Input:{db_input}
 You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
-The response format must be JSON, and the key of JSON must be "table".
 """
 PROMPT_RESPONSE = """You must respond in JSON format as following format:
 {response}
-Ensure the response is correct json and can be parsed by Python json.loads
+The response format must be JSON, and the key of JSON must be "table".
 """


-RESPONSE_FORMAT = {
-    "table": ["orders", "products"]
-}
+RESPONSE_FORMAT = {"table": ["orders", "products"]}


 PROMPT_SEP = SeparatorStyle.SINGLE.value
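Note: RESPONSE_FORMAT is rendered into the prompt via json.dumps(RESPONSE_FORMAT, indent=4), and the model is asked to answer in the same single-key shape, which the caller can then load directly:

import json

RESPONSE_FORMAT = {"table": ["orders", "products"]}
print(json.dumps(RESPONSE_FORMAT, indent=4))  # what the model sees

model_reply = '{"table": ["orders"]}'
print(json.loads(model_reply)["table"])  # ['orders']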
@@ -54,5 +51,3 @@ prompt = PromptTemplate(


 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-
@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -24,18 +23,20 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()


-class ChatUrlKnowledge (BaseChat):
+class ChatUrlKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatUrlKnowledge.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatUrlKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatUrlKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.url = url
         vector_store_config = {
             "vector_store_name": url,
@@ -54,19 +55,16 @@ class ChatUrlKnowledge (BaseChat):
         self.knowledge_embedding_client.knowledge_embedding()

     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatUrlKnowledge.value
@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

-class NormalChatOutputParser(BaseOutputParser):

+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

@@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """


-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -39,5 +38,3 @@ prompt = PromptTemplate(


 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-
@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -19,25 +18,23 @@ class ChatNormal(BaseChat):

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatNormal,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatNormal,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )

     def generate_input_values(self):
-        input_values = {
-            "input": self.current_user_input
-        }
+        input_values = {"input": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatNormal.value
@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

-class NormalChatOutputParser(BaseOutputParser):

+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

@@ -29,5 +29,3 @@ prompt = PromptTemplate(


 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-
@@ -85,9 +85,7 @@ add_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_add_knowledge_base_dialogue"
 )

-url_knowledge_dialogue = get_lang_text(
-    "knowledge_qa_type_url_knowledge_dialogue"
-)
+url_knowledge_dialogue = get_lang_text("knowledge_qa_type_url_knowledge_dialogue")

 knowledge_qa_type_list = [
     llm_native_dialogue,
@@ -205,9 +203,9 @@ def post_process_code(code):


 def get_chat_mode(selected, param=None) -> ChatScene:
-    if chat_mode_title['chat_use_plugin'] == selected:
+    if chat_mode_title["chat_use_plugin"] == selected:
         return ChatScene.ChatExecution
-    elif chat_mode_title['knowledge_qa'] == selected:
+    elif chat_mode_title["knowledge_qa"] == selected:
         mode = param
         if mode == conversation_types["default_knownledge"]:
             return ChatScene.ChatKnowledge
@@ -232,14 +230,23 @@ def chatbot_callback(state, message):


 def http_bot(
-    state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input,
-    knowledge_name
+    state,
+    selected,
+    temperature,
+    max_new_tokens,
+    plugin_selector,
+    mode,
+    sql_mode,
+    db_selector,
+    url_input,
+    knowledge_name,
 ):
     logger.info(
-        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
-    if chat_mode_title['knowledge_qa'] == selected:
+        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}"
+    )
+    if chat_mode_title["knowledge_qa"] == selected:
         scene: ChatScene = get_chat_mode(selected, mode)
-    elif chat_mode_title['chat_use_plugin'] == selected:
+    elif chat_mode_title["chat_use_plugin"] == selected:
         scene: ChatScene = get_chat_mode(selected)
     else:
         scene: ChatScene = get_chat_mode(selected, sql_mode)
@@ -251,7 +258,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
-            "user_input": state.last_user_input
+            "user_input": state.last_user_input,
         }
     elif ChatScene.ChatWithDbQA == scene:
         chat_param = {
@@ -289,7 +296,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "user_input": state.last_user_input,
-            "knowledge_name": knowledge_name
+            "knowledge_name": knowledge_name,
         }
     elif ChatScene.ChatUrlKnowledge == scene:
         chat_param = {
@@ -297,7 +304,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "user_input": state.last_user_input,
-            "url": url_input
+            "url": url_input,
         }
     else:
         state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}"
@@ -314,7 +321,11 @@ def http_bot(
         response = chat.stream_call()
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
-                state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk,chat.skip_echo_len)
+                state.messages[-1][
+                    -1
+                ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
     except Exception as e:
         print(traceback.format_exc())
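Note: the loop above consumes the model server's streamed reply; chunks are NUL-separated (taken from delimiter=b"\0" in the diff). A stripped-down client showing just that transport, with a placeholder endpoint:

import requests

response = requests.post(
    "http://127.0.0.1:8000/generate_stream",  # placeholder URL, not the project's
    json={"prompt": "hi"},
    stream=True,
    timeout=120,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        print(chunk.decode("utf-8", errors="ignore"))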
@@ -323,8 +334,8 @@ def http_bot(


 block_css = (
     code_highlight_css
     + """
 pre {
     white-space: pre-wrap; /* Since CSS 2.1 */
     white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -361,7 +372,7 @@ def build_single_model_ui():
     gr.Markdown(notice_markdown, elem_id="notice_markdown")

     with gr.Accordion(
         get_lang_text("model_control_param"), open=False, visible=False
     ) as parameter_row:
         temperature = gr.Slider(
             minimum=0.0,
@ -411,7 +422,7 @@ def build_single_model_ui():
|
|||||||
get_lang_text("sql_generate_mode_none"),
|
get_lang_text("sql_generate_mode_none"),
|
||||||
],
|
],
|
||||||
show_label=False,
|
show_label=False,
|
||||||
value=get_lang_text("sql_generate_mode_none")
|
value=get_lang_text("sql_generate_mode_none"),
|
||||||
)
|
)
|
||||||
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
|
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
|
||||||
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
|
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
|
||||||
@@ -428,15 +439,19 @@ def build_single_model_ui():
         value="",
         interactive=True,
         show_label=True,
-        type="value"
+        type="value",
     ).style(container=False)

-    def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
+    def plugin_change(
+        evt: gr.SelectData,
+    ):  # SelectData is a subclass of EventData
         print(f"You selected {evt.value} at {evt.index} from {evt.target}")
         print(f"user plugin:{plugins_select_info().get(evt.value)}")
         return plugins_select_info().get(evt.value)

-    plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
+    plugin_selected = gr.Textbox(
+        show_label=False, visible=False, placeholder="Selected"
+    )
     plugin_selector.select(plugin_change, None, plugin_selected)

     tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
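Note: plugin_change relies on Gradio's event-data injection — a handler parameter annotated with gr.SelectData receives the click metadata without being listed as an input. A self-contained sketch of the pattern (component choice and labels are illustrative, assuming a Gradio 3.x version whose components expose .select(), as plugin_selector does here):

    import gradio as gr

    def on_plugin_select(evt: gr.SelectData):
        # evt.value is the selected label, evt.index its position.
        return f"selected {evt.value} at index {evt.index}"

    with gr.Blocks() as demo:
        selector = gr.Radio(["plugin-a", "plugin-b"], label="plugins")
        selected = gr.Textbox(show_label=False)
        # inputs=None: Gradio injects the SelectData event object itself.
        selector.select(on_plugin_select, None, selected)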
@@ -456,7 +471,12 @@ def build_single_model_ui():
     )
     mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

-    url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
+    url_input = gr.Textbox(
+        label=get_lang_text("url_input_label"),
+        lines=1,
+        interactive=True,
+        visible=False,
+    )

     def show_url_input(evt: gr.SelectData):
         if evt.value == url_knowledge_dialogue:
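Note: url_input starts hidden, and show_url_input is cut off at the hunk boundary. A plausible body — a guess at the pattern, not the commit's actual code — would flip visibility via gr.update():

    def show_url_input(evt: gr.SelectData):
        # Reveal the URL box only when the URL-knowledge mode is selected.
        return gr.update(visible=(evt.value == url_knowledge_dialogue))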
@@ -559,10 +579,10 @@ def build_single_model_ui():

 def build_webdemo():
     with gr.Blocks(
         title=get_lang_text("database_smart_assistant"),
         # theme=gr.themes.Base(),
         theme=gr.themes.Default(),
         css=block_css,
     ) as demo:
         url_params = gr.JSON(visible=False)
         (
@@ -18,7 +18,14 @@ CFG = Config()


 class KnowledgeEmbedding:
-    def __init__(self, file_path, model_name, vector_store_config, local_persist=True, file_type="default"):
+    def __init__(
+        self,
+        file_path,
+        model_name,
+        vector_store_config,
+        local_persist=True,
+        file_type="default",
+    ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
         self.model_name = model_name
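Note: the constructor change is a pure reformat. A hypothetical call site, for orientation — the model name and vector_store_config keys are assumptions, not taken from this commit:

    client = KnowledgeEmbedding(
        file_path="./docs/overview.md",
        model_name="text2vec",                    # assumed model name
        vector_store_config={"vector_store_name": "default"},  # assumed key
        local_persist=True,
        file_type="default",                      # routed to MarkdownEmbedding below
    )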
@@ -63,7 +70,6 @@ class KnowledgeEmbedding:
                 vector_store_config=self.vector_store_config,
             )

-
         elif self.file_type == "default":
             embedding = MarkdownEmbedding(
                 file_path=self.file_path,
@@ -71,7 +77,6 @@ class KnowledgeEmbedding:
                 vector_store_config=self.vector_store_config,
             )

-
         return embedding

     def similar_search(self, text, topk):
@@ -13,6 +13,7 @@ from pilot.summary.mysql_db_summary import MysqlSummary
 from pilot.scene.chat_factory import ChatFactory

 CFG = Config()
+chat_factory = ChatFactory()


 class DBSummaryClient:
@@ -88,13 +89,18 @@ class DBSummaryClient:
         )
         if CFG.SUMMARY_CONFIG == "FAST":
             table_docs = knowledge_embedding_client.similar_search(query, topk)
-            related_tables = [json.loads(table_doc.page_content)["table_name"] for table_doc in table_docs]
+            related_tables = [
+                json.loads(table_doc.page_content)["table_name"]
+                for table_doc in table_docs
+            ]
         else:
             table_docs = knowledge_embedding_client.similar_search(query, 1)
             # prompt = KnownLedgeBaseQA.build_db_summary_prompt(
             #     query, table_docs[0].page_content
             # )
-            related_tables = _get_llm_response(query, dbname, table_docs[0].page_content)
+            related_tables = _get_llm_response(
+                query, dbname, table_docs[0].page_content
+            )
         related_table_summaries = []
         for table in related_tables:
             vector_store_config = {
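Note: in the FAST branch, each retrieved document stores a JSON blob in page_content — the {"table_name": ..., "table_description": ...} shape written by MysqlSummary further down. A defensive variant of the new list comprehension that skips malformed documents instead of raising; this helper is a suggestion, not part of the commit:

    import json

    def extract_table_names(table_docs):
        names = []
        for doc in table_docs:
            try:
                names.append(json.loads(doc.page_content)["table_name"])
            except (json.JSONDecodeError, KeyError):
                # Skip documents whose payload is not the expected shape.
                continue
        return names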
@@ -118,35 +124,14 @@ def _get_llm_response(query, db_input, dbsummary):
         "max_new_tokens": 512,
         "chat_session_id": uuid.uuid1(),
         "user_input": query,
-        "db_input": db_input,
+        "db_select": db_input,
         "db_summary": dbsummary,
     }
-    chat_factory = ChatFactory()
-    chat: BaseChat = chat_factory.get_implementation(ChatScene.InnerChatDBSummary.value(), **chat_param)
-    return chat.call()
-    # payload = {
-    #     "model": CFG.LLM_MODEL,
-    #     "prompt": prompt,
-    #     "temperature": float(0.7),
-    #     "max_new_tokens": int(512),
-    #     "stop": state.sep
-    #     if state.sep_style == SeparatorStyle.SINGLE
-    #     else state.sep2,
-    # }
-    # headers = {"User-Agent": "dbgpt Client"}
-    # response = requests.post(
-    #     urljoin(CFG.MODEL_SERVER, "generate"),
-    #     headers=headers,
-    #     json=payload,
-    #     timeout=120,
-    # )
-    #
-    #     print(related_tables)
-    #     return related_tables
-    # except NotCommands as e:
-    #     print("llm response error:" + e.message)
+    chat: BaseChat = chat_factory.get_implementation(
+        ChatScene.InnerChatDBSummary.value, **chat_param
+    )
+    res = chat.nostream_call()
+    return json.loads(res)["table"]



 # if __name__ == "__main__":
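Note: besides reusing the module-level chat_factory, the rewrite drops the erroneous call parentheses on ChatScene.InnerChatDBSummary.value (calling the enum's value would raise a TypeError if it is a plain string) and switches from the streaming call() to nostream_call(), parsing the result as JSON with a "table" key. Since LLM output is not always clean JSON, a tolerant parse such as the sketch below (a suggestion, not the commit's code) trims stray prose around the object first:

    import json

    def parse_related_tables(res: str) -> list:
        # Keep only the outermost {...} before decoding, in case the
        # model wrapped its JSON payload in explanatory text.
        start, end = res.find("{"), res.rfind("}") + 1
        if start == -1 or end == 0:
            raise ValueError(f"no JSON object in response: {res!r}")
        return json.loads(res[start:end])["table"]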
@@ -37,20 +37,23 @@ class MysqlSummary(DBSummary):
                     table_name=table_comment[0], table_comment=table_comment[1]
                 )
             )

             vector_table = json.dumps(
                 {"table_name": table_comment[0], "table_description": table_comment[1]}
             )
             self.vector_tables_info.append(
                 vector_table.encode("utf-8").decode("unicode_escape")
             )
+        self.table_columns_info = []
         for table_name in tables:
             table_summary = MysqlTableSummary(self.db, name, table_name)
-            self.tables[table_name] = table_summary.get_summery()
+            # self.tables[table_name] = table_summary.get_summery()
+            self.tables[table_name] = table_summary.get_columns()
+            self.table_columns_info.append(table_summary.get_columns())
             # self.tables_info.append(table_summary.get_summery())

     def get_summery(self):
-        if CFG.SUMMARY_CONFIG == "VECTOR":
+        if CFG.SUMMARY_CONFIG == "FAST":
             return self.vector_tables_info
         else:
             return self.summery.format(
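Note: the json.dumps / unicode_escape round-trip above exists so that non-ASCII table comments survive as readable characters rather than \uXXXX escapes. A worked example (table name and description are illustrative):

    import json

    vector_table = json.dumps({"table_name": "user", "table_description": "用户表"})
    # json.dumps escapes non-ASCII by default:
    #   '{"table_name": "user", "table_description": "\\u7528\\u6237\\u8868"}'
    entry = vector_table.encode("utf-8").decode("unicode_escape")
    # -> '{"table_name": "user", "table_description": "用户表"}'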
@@ -63,6 +66,9 @@ class MysqlSummary(DBSummary):
     def get_table_comments(self):
         return self.table_comments

+    def get_columns(self):
+        return self.table_columns_info
+

 class MysqlTableSummary(TableSummary):
     """Get mysql table summary template."""
@@ -78,10 +84,16 @@ class MysqlTableSummary(TableSummary):
         self.db = instance
         fields = self.db.get_fields(name)
         indexes = self.db.get_indexes(name)
+        field_names = []
         for field in fields:
             field_summary = MysqlFieldsSummary(field)
             self.fields.append(field_summary)
             self.fields_info.append(field_summary.get_summery())
+            field_names.append(field[0])
+
+        self.column_summery = """{name}({columns_info})""".format(
+            name=name, columns_info=",".join(field_names)
+        )

         for index in indexes:
             index_summary = MysqlIndexSummary(index)
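Note: only field[0] — the column name — feeds the new column_summery string. A worked example with illustrative field tuples:

    fields = [("id", "int"), ("name", "varchar(64)")]
    field_names = [field[0] for field in fields]
    column_summery = """{name}({columns_info})""".format(
        name="user", columns_info=",".join(field_names)
    )
    # -> "user(id,name)"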
@@ -96,6 +108,9 @@ class MysqlTableSummary(TableSummary):
             indexes=";".join(self.indexes_info),
         )

+    def get_columns(self):
+        return self.column_summery
+

 class MysqlFieldsSummary(FieldSummary):
     """Get mysql field summary template."""
@@ -1,4 +1,5 @@
 from pilot.vector_store.chroma_store import ChromaStore
+
 # from pilot.vector_store.milvus_store import MilvusStore

 connector = {"Chroma": ChromaStore, "Milvus": None}
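Note: connector is a plain name-to-class registry with Milvus stubbed as None. A sketch of how a caller might resolve it — the lookup key source and the store constructor signature are assumptions:

    store_cls = connector.get("Chroma")  # key would come from configuration
    if store_cls is None:
        raise NotImplementedError("vector store registered but not implemented")
    store = store_cls(vector_store_config)  # assumed constructor signature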