From 06bc4452d464277033f09f87d426b139c4294288 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Wed, 31 May 2023 15:59:50 +0800 Subject: [PATCH] Implemented a new multi-scenario dialogue architecture --- pilot/common/sql_database.py | 1 + pilot/conversation.py | 28 +- pilot/model/proxy_llm.py | 52 ++- pilot/out_parser/base.py | 5 +- pilot/prompts/prompt_new.py | 56 +-- pilot/scene/base.py | 4 +- pilot/scene/base_chat.py | 99 ++---- pilot/scene/chat_db/auto_execute/__init__.py | 0 pilot/scene/chat_db/auto_execute/chat.py | 57 +++ .../chat_db/{ => auto_execute}/out_parser.py | 4 +- .../chat_db/{ => auto_execute}/prompt.py | 27 +- pilot/scene/chat_db/chat.py | 240 ------------- .../scene/chat_db/professional_qa/__init__.py | 0 pilot/scene/chat_db/professional_qa/chat.py | 56 +++ .../chat_db/professional_qa/out_parser.py | 22 ++ pilot/scene/chat_db/professional_qa/prompt.py | 48 +++ pilot/scene/chat_execution/chat.py | 11 +- pilot/scene/chat_execution/out_parser.py | 4 +- pilot/scene/chat_execution/prompt.py | 5 +- .../chat_execution/prompt_with_command.py | 65 ---- pilot/scene/chat_factory.py | 12 +- pilot/scene/chat_knowledge/custom/chat.py | 69 ++++ .../scene/chat_knowledge/custom/out_parser.py | 22 ++ pilot/scene/chat_knowledge/custom/prompt.py | 43 +++ pilot/scene/chat_knowledge/default/chat.py | 66 ++++ .../chat_knowledge/default/out_parser.py | 22 ++ pilot/scene/chat_knowledge/default/prompt.py | 43 +++ pilot/scene/chat_knowledge/url/chat.py | 71 ++++ pilot/scene/chat_knowledge/url/out_parser.py | 22 ++ pilot/scene/chat_knowledge/url/prompt.py | 43 +++ pilot/scene/chat_normal/chat.py | 43 +++ pilot/scene/chat_normal/out_parser.py | 22 ++ pilot/scene/chat_normal/prompt.py | 50 +-- pilot/server/webserver.py | 335 ++++++------------ requirements.txt | 1 + 35 files changed, 905 insertions(+), 743 deletions(-) create mode 100644 pilot/scene/chat_db/auto_execute/__init__.py create mode 100644 pilot/scene/chat_db/auto_execute/chat.py rename pilot/scene/chat_db/{ => auto_execute}/out_parser.py (90%) rename pilot/scene/chat_db/{ => auto_execute}/prompt.py (63%) delete mode 100644 pilot/scene/chat_db/chat.py create mode 100644 pilot/scene/chat_db/professional_qa/__init__.py create mode 100644 pilot/scene/chat_db/professional_qa/chat.py create mode 100644 pilot/scene/chat_db/professional_qa/out_parser.py create mode 100644 pilot/scene/chat_db/professional_qa/prompt.py delete mode 100644 pilot/scene/chat_execution/prompt_with_command.py create mode 100644 pilot/scene/chat_knowledge/custom/chat.py create mode 100644 pilot/scene/chat_knowledge/custom/out_parser.py create mode 100644 pilot/scene/chat_knowledge/custom/prompt.py create mode 100644 pilot/scene/chat_knowledge/default/chat.py create mode 100644 pilot/scene/chat_knowledge/default/out_parser.py create mode 100644 pilot/scene/chat_knowledge/default/prompt.py create mode 100644 pilot/scene/chat_knowledge/url/chat.py create mode 100644 pilot/scene/chat_knowledge/url/out_parser.py create mode 100644 pilot/scene/chat_knowledge/url/prompt.py create mode 100644 pilot/scene/chat_normal/chat.py create mode 100644 pilot/scene/chat_normal/out_parser.py diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py index 2b8d6fe4b..c3ac5bdc6 100644 --- a/pilot/common/sql_database.py +++ b/pilot/common/sql_database.py @@ -277,6 +277,7 @@ class Database: def run(self, session, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results.""" + print("sql run:" + command) cursor = session.execute(text(command)) if cursor.returns_rows: if fetch == "all": diff --git a/pilot/conversation.py b/pilot/conversation.py index ee07adf33..8a758dd51 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -105,18 +105,14 @@ class Conversation: } -def gen_sqlgen_conversation(dbname): - from pilot.connections.mysql import MySQLOperator - - mo = MySQLOperator(**(DB_SETTINGS)) - - message = "" - - schemas = mo.get_schema(dbname) - for s in schemas: - message += s["schema_info"] + ";" - return f"Database {dbname} Schema information as follows: {message}\n" - +conv_default = Conversation( + system = None, + roles=("human", "ai"), + messages= (), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) conv_one_shot = Conversation( system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. " @@ -261,7 +257,7 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回 # question: # {question} # """ -default_conversation = conv_one_shot +default_conversation = conv_default chat_mode_title = { @@ -289,8 +285,4 @@ conv_templates = { "conv_one_shot": conv_one_shot, "vicuna_v1": conv_vicuna_v1, "auto_dbgpt_one_shot": auto_dbgpt_one_shot, -} - -if __name__ == "__main__": - message = gen_sqlgen_conversation("dbgpt") - print(message) +} \ No newline at end of file diff --git a/pilot/model/proxy_llm.py b/pilot/model/proxy_llm.py index 3242603d3..92887cfc6 100644 --- a/pilot/model/proxy_llm.py +++ b/pilot/model/proxy_llm.py @@ -21,22 +21,46 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) } messages = prompt.split(stop) - # Add history conversation - for i in range(1, len(messages) - 2, 2): - history.append( - {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]}, - ) - history.append( - { - "role": "system", - "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1], - } - ) + for message in messages: + if len(message) <= 0: + continue + if "human:" in message: + history.append( + {"role": "user", "content": message.split("human:")[1]}, + ) + elif "system:" in message: + history.append( + { + "role": "system", + "content": message.split("system:")[1], + } + ) + elif "ai:" in message: + history.append( + { + "role": "ai", + "content": message.split("ai:")[1], + } + ) + else: + history.append( + { + "role": "system", + "content": message, + } + ) + + # 把最后一个用户的信息移动到末尾 + temp_his = history[::-1] + last_user_input = None + for m in temp_his: + if m["role"] == "user": + last_user_input = m + if last_user_input: + history.remove(last_user_input) + history.append(last_user_input) - # Add user query - query = messages[-2].split(ROLE_USER + ":")[1] - history.append({"role": "user", "content": query}) payloads = { "model": "gpt-3.5-turbo", # just for test, remove this later "messages": history, diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 57f0a7f7e..fee3eda37 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -36,7 +36,7 @@ class BaseOutputParser(ABC): self.sep = sep self.is_stream_out = is_stream_out - def __post_process_code(code): + def __post_process_code(self, code): sep = "\n```" if sep in code: blocks = code.split(sep) @@ -92,7 +92,7 @@ class BaseOutputParser(ABC): ai_response = ai_response.replace("\n", "") ai_response = ai_response.replace("\_", "_") ai_response = ai_response.replace("\*", "*") - print("un_stream clear response:{}", ai_response) + print("un_stream ai response:", ai_response) return ai_response else: raise ValueError("Model server error!code=" + respObj_ex["error_code"]) @@ -140,6 +140,7 @@ class BaseOutputParser(ABC): cleaned_output = m.group(0) else: raise ValueError("model server out not fllow the prompt!") + cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '') return cleaned_output def parse_view_response(self, ai_text) -> str: diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 389b1a33e..888b6f81e 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -31,15 +31,15 @@ DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { class PromptTemplate(BaseModel, ABC): input_variables: List[str] """A list of the names of the variables the prompt template expects.""" - template_scene: str + template_scene: Optional[str] - template_define: str + template_define: Optional[str] """this template define""" - template: str + template: Optional[str] """The prompt template.""" template_format: str = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - response_format: str + response_format: Optional[str] """default use stream out""" stream_out: bool = True """""" @@ -57,52 +57,12 @@ class PromptTemplate(BaseModel, ABC): """Return the prompt type key.""" return "prompt" - def _generate_command_string(self, command: Dict[str, Any]) -> str: - """ - Generate a formatted string representation of a command. - - Args: - command (dict): A dictionary containing command information. - - Returns: - str: The formatted command string. - """ - args_string = ", ".join( - f'"{key}": "{value}"' for key, value in command["args"].items() - ) - return f'{command["label"]}: "{command["name"]}", args: {args_string}' - - def _generate_numbered_list(self, items: List[Any], item_type="list") -> str: - """ - Generate a numbered list from given items based on the item_type. - - Args: - items (list): A list of items to be numbered. - item_type (str, optional): The type of items in the list. - Defaults to 'list'. - - Returns: - str: The formatted numbered list. - """ - if item_type == "command": - command_strings = [] - if self.command_registry: - command_strings += [ - str(item) - for item in self.command_registry.commands.values() - if item.enabled - ] - # terminate command is added manually - command_strings += [self._generate_command_string(item) for item in items] - return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) - else: - return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) - def format(self, **kwargs: Any) -> str: """Format the prompt with the inputs.""" - - kwargs["response"] = json.dumps(self.response_format, indent=4) - return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + if self.template: + if self.response_format: + kwargs["response"] = json.dumps(self.response_format, indent=4) + return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) def add_goals(self, goal: str) -> None: self.goals.append(goal) diff --git a/pilot/scene/base.py b/pilot/scene/base.py index 9fcc6fb31..21f605fed 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -2,8 +2,10 @@ from enum import Enum class ChatScene(Enum): - ChatWithDb = "chat_with_db" + ChatWithDbExecute = "chat_with_db_execute" + ChatWithDbQA = "chat_with_db_qa" ChatExecution = "chat_execution" ChatKnowledge = "chat_default_knowledge" ChatNewKnowledge = "chat_new_knowledge" + ChatUrlKnowledge = "chat_url_knowledge" ChatNormal = "chat_normal" diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 1398a476a..798e071e3 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -56,7 +56,7 @@ class BaseChat(ABC): arbitrary_types_allowed = True - def __init__(self, chat_mode, chat_session_id, current_user_input): + def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input): self.chat_session_id = chat_session_id self.chat_mode = chat_mode self.current_user_input: str = current_user_input @@ -64,12 +64,12 @@ class BaseChat(ABC): ### TODO self.memory = FileHistoryMemory(chat_session_id) ### load prompt template - self.prompt_template: PromptTemplate = CFG.prompt_templates[ - self.chat_mode.value - ] + self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value] self.history_message: List[OnceConversation] = [] self.current_message: OnceConversation = OnceConversation() self.current_tokens_used: int = 0 + self.temperature = temperature + self.max_new_tokens = max_new_tokens ### load chat_session_id's chat historys self._load_history(self.chat_session_id) @@ -92,15 +92,17 @@ class BaseChat(ABC): pass def __call_base(self): - input_values = self.generate_input_values() + input_values = self.generate_input_values() ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) self.current_message.start_date = datetime.datetime.now() # TODO self.current_message.tokens = 0 + current_prompt = None - current_prompt = self.prompt_template.format(**input_values) + if self.prompt_template.template: + current_prompt = self.prompt_template.format(**input_values) ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 if self.history_message: @@ -108,8 +110,8 @@ class BaseChat(ABC): logger.info( f"There are already {len(self.history_message)} rounds of conversations!" ) - - self.current_message.add_system_message(current_prompt) + if current_prompt: + self.current_message.add_system_message(current_prompt) payload = { "model": self.llm_model, @@ -118,7 +120,6 @@ class BaseChat(ABC): "max_new_tokens": int(self.max_new_tokens), "stop": self.prompt_template.sep, } - logger.info(f"Requert: \n{payload}") return payload def stream_call(self): @@ -127,30 +128,18 @@ class BaseChat(ABC): ai_response_text = "" try: show_info = "" + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + timeout=120, + ) - # response = requests.post( - # urljoin(CFG.MODEL_SERVER, "generate_stream"), - # headers=headers, - # json=payload, - # timeout=120, - # ) - # - # ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) + ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) - # for resp_text_trunck in ai_response_text: - # show_info = resp_text_trunck - # yield resp_text_trunck + "▌" - # - - #### MOCK TEST - def mock_stream_out(): - for i in range(1, 11): - time.sleep(0.5) - yield f"Message:{i}" - - for msg in mock_stream_out(): - show_info = msg - yield msg + "▌" + for resp_text_trunck in ai_response_text: + show_info = resp_text_trunck + yield resp_text_trunck + "▌" self.current_message.add_ai_message(show_info) @@ -186,13 +175,13 @@ class BaseChat(ABC): result = self.do_with_prompt_response(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): - if prompt_define_response.thoughts.get("speak"): + if hasattr(prompt_define_response.thoughts, "speak"): self.current_message.add_view_message( self.prompt_template.output_parser.parse_view_response( prompt_define_response.thoughts.get("speak"), result ) ) - elif prompt_define_response.thoughts.get("reasoning"): + elif hasattr(prompt_define_response.thoughts, "reasoning"): self.current_message.add_view_message( self.prompt_template.output_parser.parse_view_response( prompt_define_response.thoughts.get("reasoning"), result @@ -223,15 +212,18 @@ class BaseChat(ABC): def call(self): if self.prompt_template.stream_out: - yield self.stream_call() + yield self.stream_call() else: return self.nostream_call() def generate_llm_text(self) -> str: - text = self.prompt_template.template_define + self.prompt_template.sep - ### 线处理历史信息 + text = "" + if self.prompt_template.template_define: + text = self.prompt_template.template_define + self.prompt_template.sep + + ### 处理历史信息 if len(self.history_message) > self.chat_retention_rounds: - ### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 + ### 使用历史信息的第一轮和最后n轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 for first_message in self.history_message[0].messages: if not isinstance(first_message, ViewMessage): text += ( @@ -262,8 +254,8 @@ class BaseChat(ABC): + message.content + self.prompt_template.sep ) - ### current conversation + for now_message in self.current_message.messages: text += ( now_message.type + ":" + now_message.content + self.prompt_template.sep @@ -298,34 +290,3 @@ class BaseChat(ABC): """ pass - -if __name__ == "__main__": - # - # def call_back(t, m): - # print(t) - # print(m) - # - # def my_fn(call_fn, xx): - # call_fn(1, xx) - # - # - # my_fn(call_back, "1231") - - def my_generator(): - while True: - value = yield - print('Received value:', value) - if value == 'stop': - return - - - # 创建生成器对象 - gen = my_generator() - - # 启动生成器 - next(gen) - - # 发送数据到生成器 - gen.send('Hello') - gen.send('World') - gen.send('stop') diff --git a/pilot/scene/chat_db/auto_execute/__init__.py b/pilot/scene/chat_db/auto_execute/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py new file mode 100644 index 000000000..2b8918fde --- /dev/null +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -0,0 +1,57 @@ +import json + +from pilot.scene.base_message import ( + HumanMessage, + ViewMessage, +) +from pilot.scene.base_chat import BaseChat +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config +from pilot.common.markdown_text import ( + generate_htm_table, +) +from pilot.scene.chat_db.auto_execute.prompt import prompt + +CFG = Config() + + +class ChatWithDbAutoExecute(BaseChat): + chat_scene: str = ChatScene.ChatWithDbExecute.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatWithDbExecute, + chat_session_id=chat_session_id, + current_user_input=user_input) + if not db_name: + raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should chose db!") + self.db_name = db_name + self.database = CFG.local_db + # 准备DB信息(拿到指定库的链接) + self.db_connect = self.database.get_session(self.db_name) + self.top_k: int = 5 + + def generate_input_values(self): + input_values = { + "input": self.current_user_input, + "top_k": str(self.top_k), + "dialect": self.database.dialect, + "table_info": self.database.table_simple_info(self.db_connect) + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return self.database.run(self.db_connect, prompt_response.sql) + + + +if __name__ == "__main__": + ss = "{\n \"thoughts\": \"to get the user's city, we need to join the users table with the tran_order table using the user_name column. we also need to filter the results to only show orders for user test1.\",\n \"sql\": \"select o.order_id, o.product_name, u.city from tran_order o join users u on o.user_name = u.user_name where o.user_name = 'test1' limit 5\"\n}" + ss.strip().replace('\n', '').replace('\\n', '').replace('', '').replace(' ', '').replace('\\', '').replace('\\', '') + print(ss) + json.loads(ss) \ No newline at end of file diff --git a/pilot/scene/chat_db/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py similarity index 90% rename from pilot/scene/chat_db/out_parser.py rename to pilot/scene/chat_db/auto_execute/out_parser.py index 307aff680..cb059feb8 100644 --- a/pilot/scene/chat_db/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -22,7 +22,9 @@ class DbChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text): - response = json.loads(super().parse_prompt_response(model_out_text)) + clean_str = super().parse_prompt_response(model_out_text); + print("clean prompt response:", clean_str) + response = json.loads(clean_str) sql, thoughts = response["sql"], response["thoughts"] return SqlAction(sql, thoughts) diff --git a/pilot/scene/chat_db/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py similarity index 63% rename from pilot/scene/chat_db/prompt.py rename to pilot/scene/chat_db/auto_execute/prompt.py index aeaf994c0..9a381345f 100644 --- a/pilot/scene/chat_db/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -1,36 +1,30 @@ import json +import importlib from pilot.prompts.prompt_new import PromptTemplate from pilot.configs.config import Config from pilot.scene.base import ChatScene -from pilot.scene.chat_db.out_parser import DbChatOutputParser, SqlAction +from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction from pilot.common.schema import SeparatorStyle CFG = Config() PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers""" -PROMPT_SUFFIX = """Only use the following tables: -{table_info} - -Question: {input} - -""" _DEFAULT_TEMPLATE = """ You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database. Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. +If the given table is beyond the scope of use, do not use it forcibly. Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. """ -_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question. -Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database. -Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers. -Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. -Pay attention to use CURDATE() function to get the current date, if the question involves "today". +PROMPT_SUFFIX = """Only use the following tables: +{table_info} +Question: {input} """ @@ -49,17 +43,16 @@ RESPONSE_FORMAT = { } RESPONSE_FORMAT_SIMPLE = { - "thoughts": "thoughts summary to say to user", + "thoughts": "thoughts summary to say to user", "sql": "SQL Query to run", } - PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = False -chat_db_prompt = PromptTemplate( - template_scene=ChatScene.ChatWithDb.value, +prompt = PromptTemplate( + template_scene=ChatScene.ChatWithDbExecute.value, input_variables=["input", "table_info", "dialect", "top_k", "response"], response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4), template_define=PROMPT_SCENE_DEFINE, @@ -69,5 +62,5 @@ chat_db_prompt = PromptTemplate( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT ), ) +CFG.prompt_templates.update({prompt.template_scene: prompt}) -CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt}) diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py deleted file mode 100644 index 745e9804d..000000000 --- a/pilot/scene/chat_db/chat.py +++ /dev/null @@ -1,240 +0,0 @@ -import requests -import datetime -import threading -import json -import traceback -from urllib.parse import urljoin -from sqlalchemy import ( - MetaData, - Table, - create_engine, - inspect, - select, - text, -) -from typing import Any, Iterable, List, Optional - -from pilot.scene.base_message import ( - BaseMessage, - SystemMessage, - HumanMessage, - AIMessage, - ViewMessage, -) -from pilot.scene.base_chat import BaseChat, logger, headers -from pilot.scene.base import ChatScene -from pilot.common.sql_database import Database -from pilot.configs.config import Config -from pilot.scene.chat_db.out_parser import SqlAction -from pilot.configs.model_config import LOGDIR, DATASETS_DIR -from pilot.utils import ( - build_logger, - server_error_msg, -) -from pilot.common.markdown_text import ( - generate_markdown_table, - generate_htm_table, - datas_to_table_html, -) -from pilot.scene.chat_db.prompt import chat_db_prompt -from pilot.out_parser.base import BaseOutputParser -from pilot.scene.chat_db.out_parser import DbChatOutputParser - -CFG = Config() - - -class ChatWithDb(BaseChat): - chat_scene: str = ChatScene.ChatWithDb.value - - """Number of results to return from the query""" - - def __init__(self, chat_session_id, db_name, user_input): - """ """ - super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input) - if not db_name: - raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!") - self.db_name = db_name - self.database = CFG.local_db - # 准备DB信息(拿到指定库的链接) - self.db_connect = self.database.get_session(self.db_name) - self.top_k: int = 5 - - def generate_input_values(self): - input_values = { - "input": self.current_user_input, - "top_k": str(self.top_k), - "dialect": self.database.dialect, - "table_info": self.database.table_simple_info(self.db_connect) - } - return input_values - - def do_with_prompt_response(self, prompt_response): - return self.database.run(self.db_connect, prompt_response.sql) - - # def call(self) -> str: - # input_values = { - # "input": self.current_user_input, - # "top_k": str(self.top_k), - # "dialect": self.database.dialect, - # "table_info": self.database.table_simple_info(self.db_connect), - # # "stop": self.sep_style, - # } - # - # ### Chat sequence advance - # self.current_message.chat_order = len(self.history_message) + 1 - # self.current_message.add_user_message(self.current_user_input) - # self.current_message.start_date = datetime.datetime.now() - # # TODO - # self.current_message.tokens = 0 - # - # current_prompt = self.prompt_template.format(**input_values) - # - # ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 - # if self.history_message: - # ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 - # logger.info( - # f"There are already {len(self.history_message)} rounds of conversations!" - # ) - # - # self.current_message.add_system_message(current_prompt) - # - # payload = { - # "model": self.llm_model, - # "prompt": self.generate_llm_text(), - # "temperature": float(self.temperature), - # "max_new_tokens": int(self.max_new_tokens), - # "stop": self.prompt_template.sep, - # } - # logger.info(f"Requert: \n{payload}") - # ai_response_text = "" - # try: - # ### 走非流式的模型服务接口 - # - # response = requests.post( - # urljoin(CFG.MODEL_SERVER, "generate"), - # headers=headers, - # json=payload, - # timeout=120, - # ) - # ai_response_text = ( - # self.prompt_template.output_parser.parse_model_server_out(response) - # ) - # self.current_message.add_ai_message(ai_response_text) - # prompt_define_response = ( - # self.prompt_template.output_parser.parse_prompt_response( - # ai_response_text - # ) - # ) - # - # result = self.database.run(self.db_connect, prompt_define_response.sql) - # - # if hasattr(prompt_define_response, "thoughts"): - # if prompt_define_response.thoughts.get("speak"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("speak"), result - # ) - # ) - # elif prompt_define_response.thoughts.get("reasoning"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("reasoning"), result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts, result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response, result - # ) - # ) - # - # except Exception as e: - # print(traceback.format_exc()) - # logger.error("model response parase faild!" + str(e)) - # self.current_message.add_view_message( - # f"""ERROR!{str(e)}\n {ai_response_text} """ - # ) - # ### 对话记录存储 - # self.memory.append(self.current_message) - - def chat_show(self): - ret = [] - # 单论对话只能有一次User 记录 和一次 AI 记录 - # TODO 推理过程前端展示。。。 - for message in self.current_message.messages: - if isinstance(message, HumanMessage): - ret[-1][-2] = message.content - # 是否展示推理过程 - if isinstance(message, ViewMessage): - ret[-1][-1] = message.content - - return ret - - @property - def chat_type(self) -> str: - return ChatScene.ChatExecution.value - - -if __name__ == "__main__": - # chat: ChatWithDb = ChatWithDb("chat123", "gpt-user", "查询用户信息") - # - # chat.call() - # - # resp = chat.chat_show() - # - # print(vars(resp)) - - # memory = FileHistoryMemory("test123") - # once1 = OnceConversation() - # once1.add_user_message("问题测试") - # once1.add_system_message("prompt1") - # once1.add_system_message("prompt2") - # once1.chat_order = 1 - # once1.set_start_time(datetime.datetime.now()) - # memory.append(once1) - # - # once = OnceConversation() - # once.add_user_message("问题测试2") - # once.add_system_message("prompt3") - # once.add_system_message("prompt4") - # once.chat_order = 2 - # once.set_start_time(datetime.datetime.now()) - # memory.append(once) - - db: Database = CFG.local_db - db_connect = db.get_session("gpt-user") - data = db.run(db_connect, "select * from users") - print(generate_htm_table(data)) - - # - # print(db.run(db_connect, "select * from users")) - # - # # - # # def print_numbers(): - # # db_connect1 = db.get_session("dbgpt-test") - # # cursor1 = db_connect1.execute(text("select * from test_name")) - # # if cursor1.returns_rows: - # # result1 = cursor1.fetchall() - # # print( result1) - # # - # # - # # # 创建线程 - # # t = threading.Thread(target=print_numbers) - # # # 启动线程 - # # t.start() - # - # print(db.run(db_connect, "select * from tran_order")) - # - # print(db.run(db_connect, "select count(*) as aa from tran_order")) - # - # print(db.table_simple_info(db_connect)) - # my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9] - # index = 3 - # last_three_elements = my_list[-index:] - # print(last_three_elements) diff --git a/pilot/scene/chat_db/professional_qa/__init__.py b/pilot/scene/chat_db/professional_qa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py new file mode 100644 index 000000000..fbf5a8bb4 --- /dev/null +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -0,0 +1,56 @@ +from pilot.scene.base_message import ( + HumanMessage, + ViewMessage, +) +from pilot.scene.base_chat import BaseChat +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config +from pilot.common.markdown_text import ( + generate_htm_table, +) +from pilot.scene.chat_db.professional_qa.prompt import prompt + +CFG = Config() + + +class ChatWithDbQA(BaseChat): + chat_scene: str = ChatScene.ChatWithDbQA.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatWithDbQA, + chat_session_id=chat_session_id, + current_user_input=user_input) + self.db_name = db_name + if db_name: + self.database = CFG.local_db + # 准备DB信息(拿到指定库的链接) + self.db_connect = self.database.get_session(self.db_name) + self.top_k: int = 5 + + def generate_input_values(self): + + table_info = "" + dialect = "mysql" + if self.db_name: + table_info = self.database.table_simple_info(self.db_connect) + dialect = self.database.dialect + + input_values = { + "input": self.current_user_input, + "top_k": str(self.top_k), + "dialect": dialect, + "table_info": table_info + } + return input_values + + def do_with_prompt_response(self, prompt_response): + if self.auto_execute: + return self.database.run(self.db_connect, prompt_response.sql) + else: + return prompt_response diff --git a/pilot/scene/chat_db/professional_qa/out_parser.py b/pilot/scene/chat_db/professional_qa/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_db/professional_qa/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_db/professional_qa/prompt.py b/pilot/scene/chat_db/professional_qa/prompt.py new file mode 100644 index 000000000..00fc87c03 --- /dev/null +++ b/pilot/scene/chat_db/professional_qa/prompt.py @@ -0,0 +1,48 @@ +import json +import importlib +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.scene.chat_db.professional_qa.out_parser import NormalChatOutputParser +from pilot.common.schema import SeparatorStyle + +CFG = Config() + +PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """ + +PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info: +{table_info} + +Question: {input} + +""" + +_DEFAULT_TEMPLATE = """ +You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. +Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. +You can order the results by a relevant column to return the most interesting examples in the database. +Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. +Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. + +""" + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatWithDbQA.value, + input_variables=["input", "table_info", "dialect", "top_k"], + response_format=None, + template_define=PROMPT_SCENE_DEFINE, + template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX , + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index 210b2ad77..27e79e3a5 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -10,8 +10,7 @@ from pilot.scene.base import ChatScene from pilot.configs.config import Config from pilot.commands.command import execute_command from pilot.prompts.generator import PluginPromptGenerator - -from pilot.scene.chat_execution.prompt import chat_plugin_prompt +from pilot.scene.chat_execution.prompt import prompt CFG = Config() @@ -20,8 +19,12 @@ class ChatWithPlugin(BaseChat): plugins_prompt_generator:PluginPromptGenerator select_plugin: str = None - def __init__(self, chat_session_id, user_input, plugin_selector:str=None): - super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input) + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None): + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatExecution, + chat_session_id=chat_session_id, + current_user_input=user_input) self.plugins_prompt_generator = PluginPromptGenerator() self.plugins_prompt_generator.command_registry = CFG.command_registry # 加载插件中可用命令 diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py index f3f9e683e..ff5b6a0d7 100644 --- a/pilot/scene/chat_execution/out_parser.py +++ b/pilot/scene/chat_execution/out_parser.py @@ -20,8 +20,8 @@ class PluginChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: response = json.loads(super().parse_prompt_response(model_out_text)) - sql, thoughts = response["command"], response["thoughts"] - return PluginAction(sql, thoughts) + command, thoughts = response["command"], response["thoughts"] + return PluginAction(command, thoughts) def parse_view_response(self, ai_text) -> str: return super().parse_view_response(ai_text) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index 44f564afe..6875689cf 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -1,4 +1,5 @@ import json +import importlib from pilot.prompts.prompt_new import PromptTemplate from pilot.configs.config import Config from pilot.scene.base import ChatScene @@ -50,7 +51,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value ### Whether the model service is streaming output PROMPT_NEED_NEED_STREAM_OUT = False -chat_plugin_prompt = PromptTemplate( +prompt = PromptTemplate( template_scene=ChatScene.ChatExecution.value, input_variables=["input", "constraints", "commands_infos", "response"], response_format=json.dumps(RESPONSE_FORMAT, indent=4), @@ -62,4 +63,4 @@ chat_plugin_prompt = PromptTemplate( ), ) -CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt}) +CFG.prompt_templates.update({prompt.template_scene: prompt}) \ No newline at end of file diff --git a/pilot/scene/chat_execution/prompt_with_command.py b/pilot/scene/chat_execution/prompt_with_command.py deleted file mode 100644 index e3469d7c2..000000000 --- a/pilot/scene/chat_execution/prompt_with_command.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -from pilot.prompts.prompt_new import PromptTemplate -from pilot.configs.config import Config -from pilot.scene.base import ChatScene -from pilot.common.schema import SeparatorStyle - -from pilot.scene.chat_execution.out_parser import PluginChatOutputParser - - -CFG = Config() - -PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications.""" - -PROMPT_SUFFIX = """ -Goals: - {input} - -""" - -_DEFAULT_TEMPLATE = """ -Constraints: - Exclusively use the commands listed in double quotes e.g. "command name" - Reflect on past decisions and strategies to refine your approach. - Constructively self-criticize your big-picture behavior constantly. - {constraints} - -Commands: - {commands_infos} -""" - - -PROMPT_RESPONSE = """You must respond in JSON format as following format: -{response} - -Ensure the response is correct json and can be parsed by Python json.loads -""" - -RESPONSE_FORMAT = { - "thoughts": { - "text": "thought", - "reasoning": "reasoning", - "plan": "- short bulleted\n- list that conveys\n- long-term plan", - "criticism": "constructive self-criticism", - "speak": "thoughts summary to say to user", - }, - "command": {"name": "command name", "args": {"arg name": "value"}}, -} - -PROMPT_SEP = SeparatorStyle.SINGLE.value -### Whether the model service is streaming output -PROMPT_NEED_NEED_STREAM_OUT = False - -chat_plugin_prompt = PromptTemplate( - template_scene=ChatScene.ChatExecution.value, - input_variables=["input", "table_info", "dialect", "top_k", "response"], - response_format=json.dumps(RESPONSE_FORMAT, indent=4), - template_define=PROMPT_SCENE_DEFINE, - template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE, - stream_out=PROMPT_NEED_NEED_STREAM_OUT, - output_parser=PluginChatOutputParser( - sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT - ), -) - -CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt}) diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py index 97c547390..7a346cbda 100644 --- a/pilot/scene/chat_factory.py +++ b/pilot/scene/chat_factory.py @@ -1,8 +1,14 @@ from pilot.scene.base_chat import BaseChat from pilot.singleton import Singleton -from pilot.scene.chat_db.chat import ChatWithDb +import inspect +import importlib from pilot.scene.chat_execution.chat import ChatWithPlugin - +from pilot.scene.chat_normal.chat import ChatNormal +from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA +from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute +from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge +from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge +from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge class ChatFactory(metaclass=Singleton): @staticmethod @@ -13,5 +19,5 @@ class ChatFactory(metaclass=Singleton): if cls.chat_scene == chat_mode: implementation = cls(**kwargs) if implementation == None: - raise Exception("Invalid implementation name:" + chat_mode) + raise Exception(f"Invalid implementation name:{chat_mode}") return implementation diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py new file mode 100644 index 000000000..7b9a11f85 --- /dev/null +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -0,0 +1,69 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) + +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, + VECTOR_SEARCH_TOP_K, +) + +from pilot.scene.chat_normal.prompt import prompt +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding + +CFG = Config() + + +class ChatNewKnowledge (BaseChat): + chat_scene: str = ChatScene.ChatNewKnowledge.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatNewKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input) + self.knowledge_name = knowledge_name + vector_store_config = { + "vector_store_name": knowledge_name, + "text_field": "content", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + self.knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) + + + def generate_input_values(self): + docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = docs[:2000] + input_values = { + "context": docs, + "question": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatNewKnowledge.value diff --git a/pilot/scene/chat_knowledge/custom/out_parser.py b/pilot/scene/chat_knowledge/custom/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_knowledge/custom/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py new file mode 100644 index 000000000..175deaddb --- /dev/null +++ b/pilot/scene/chat_knowledge/custom/prompt.py @@ -0,0 +1,43 @@ +import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser + + +CFG = Config() + +_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 + 已知内容: + {context} + 问题: + {question} +""" + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatNewKnowledge.value, + input_variables=["context", "question"], + response_format=None, + template_define=None, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + + diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py new file mode 100644 index 000000000..978570d91 --- /dev/null +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -0,0 +1,66 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) + +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, + VECTOR_SEARCH_TOP_K, +) + +from pilot.scene.chat_normal.prompt import prompt +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding + +CFG = Config() + + +class ChatDefaultKnowledge (BaseChat): + chat_scene: str = ChatScene.ChatKnowledge.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input) + vector_store_config = { + "vector_store_name": "default", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + self.knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) + + def generate_input_values(self): + docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = docs[:2000] + input_values = { + "context": docs, + "question": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatKnowledge.value diff --git a/pilot/scene/chat_knowledge/default/out_parser.py b/pilot/scene/chat_knowledge/default/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_knowledge/default/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py new file mode 100644 index 000000000..d2ba473ab --- /dev/null +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -0,0 +1,43 @@ +import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser + + +CFG = Config() + +_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 + 已知内容: + {context} + 问题: + {question} +""" + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatKnowledge.value, + input_variables=["context", "question"], + response_format=None, + template_define=None, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + + diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py new file mode 100644 index 000000000..0c54f6001 --- /dev/null +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -0,0 +1,71 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) + +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, + VECTOR_SEARCH_TOP_K, +) + +from pilot.scene.chat_normal.prompt import prompt +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding + +CFG = Config() + + +class ChatUrlKnowledge (BaseChat): + chat_scene: str = ChatScene.ChatUrlKnowledge.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatUrlKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input) + self.url = url + vector_store_config = { + "vector_store_name": url, + "text_field": "content", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + self.knowledge_embedding_client = KnowledgeEmbedding( + file_path=url, + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) + + # url soruce in vector + self.knowledge_embedding_client.knowledge_embedding() + + def generate_input_values(self): + docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = docs[:2000] + input_values = { + "context": docs, + "question": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatUrlKnowledge.value diff --git a/pilot/scene/chat_knowledge/url/out_parser.py b/pilot/scene/chat_knowledge/url/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_knowledge/url/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py new file mode 100644 index 000000000..a5c1fe226 --- /dev/null +++ b/pilot/scene/chat_knowledge/url/prompt.py @@ -0,0 +1,43 @@ +import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser + + +CFG = Config() + +_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 + 已知内容: + {context} + 问题: + {question} +""" + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatUrlKnowledge.value, + input_variables=["context", "question"], + response_format=None, + template_define=None, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + + diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py new file mode 100644 index 000000000..edd6ac53c --- /dev/null +++ b/pilot/scene/chat_normal/chat.py @@ -0,0 +1,43 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) +from pilot.scene.chat_normal.prompt import prompt + +CFG = Config() + + +class ChatNormal(BaseChat): + chat_scene: str = ChatScene.ChatNormal.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatNormal, + chat_session_id=chat_session_id, + current_user_input=user_input) + + def generate_input_values(self): + input_values = { + "input": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatNormal.value diff --git a/pilot/scene/chat_normal/out_parser.py b/pilot/scene/chat_normal/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_normal/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py index fd21f2102..7a11cf3ff 100644 --- a/pilot/scene/chat_normal/prompt.py +++ b/pilot/scene/chat_normal/prompt.py @@ -1,31 +1,33 @@ import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser -def stream_write_and_read(lst): - # 对lst使用yield from进行可迭代对象的扁平化 - yield from lst - while True: - val = yield - lst.append(val) +CFG = Config() + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatNormal.value, + input_variables=["input"], + response_format=None, + template_define=None, + template=None, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) -if __name__ == "__main__": - # 创建一个空列表 - my_list = [] +CFG.prompt_templates.update({prompt.template_scene: prompt}) - # 使用生成器写入数据 - stream_writer = stream_write_and_read(my_list) - next(stream_writer) - stream_writer.send(10) - print(1) - stream_writer.send(20) - print(2) - stream_writer.send(30) - print(3) - # 使用生成器读取数据 - stream_reader = stream_write_and_read(my_list) - next(stream_reader) - print(stream_reader.send(None)) - print(stream_reader.send(None)) - print(stream_reader.send(None)) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 2e8f61016..515701255 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - +import traceback import argparse import datetime import json @@ -9,7 +9,6 @@ import shutil import sys import time import uuid -from urllib.parse import urljoin import gradio as gr import requests @@ -216,19 +215,26 @@ def post_process_code(code): return code -def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene: +def get_chat_mode(selected, param=None) -> ChatScene: if chat_mode_title['chat_use_plugin'] == selected: return ChatScene.ChatExecution elif chat_mode_title['knowledge_qa'] == selected: + mode= param if mode == conversation_types["default_knownledge"]: return ChatScene.ChatKnowledge elif mode == conversation_types["custome"]: return ChatScene.ChatNewKnowledge + elif mode == conversation_types["url"]: + return ChatScene.ChatUrlKnowledge + else: + return ChatScene.ChatNormal else: - if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector: - return ChatScene.ChatWithDb + sql_mode= param + if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: + return ChatScene.ChatWithDbExecute + else: + return ChatScene.ChatWithDbQA - return ChatScene.ChatNormal def chatbot_callback(state, message): print(f"chatbot_callback:{message}") @@ -237,244 +243,99 @@ def chatbot_callback(state, message): def http_bot( - state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request + state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name ): - logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}") - start_tstamp = time.time() - scene:ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector) - print(f"now chat scene:{scene.value}") - model_name = CFG.LLM_MODEL - if ChatScene.ChatWithDb == scene: - logger.info("chat with db mode use new architecture design!") + logger.info(f"User message send!{state.conv_id},{selected}") + if chat_mode_title['knowledge_qa'] == selected: + scene: ChatScene = get_chat_mode(selected, mode) + elif chat_mode_title['chat_use_plugin'] == selected: + scene: ChatScene = get_chat_mode(selected) + else: + scene: ChatScene = get_chat_mode(selected, sql_mode) + print(f"chat scene:{scene.value}") + + if ChatScene.ChatWithDbExecute == scene: chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "db_name": db_selector, + "user_input": state.last_user_input + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatWithDbQA == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, "chat_session_id": state.conv_id, "db_name": db_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - chat.call() - - state.messages[-1][-1] = chat.current_ai_response() - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - elif ChatScene.ChatExecution == scene: - logger.info("plugin mode use new architecture design!") chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, "chat_session_id": state.conv_id, "plugin_selector": plugin_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - strem_generate = chat.stream_call() + elif ChatScene.ChatNormal == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatKnowledge == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatNewKnowledge == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + "knowledge_name": knowledge_name + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatUrlKnowledge == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + "url": url_input + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - for msg in strem_generate: - state.messages[-1][-1] = msg + if not chat.prompt_template.stream_out: + logger.info("not stream out, wait model response!") + state.messages[-1][-1] = chat.nostream_call() + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + else: + logger.info("stream out start!") + try: + stream_gen = chat.stream_call() + for msg in stream_gen: + state.messages[-1][-1] = msg + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + except Exception as e: + print(traceback.format_exc()) + state.messages[-1][-1] = "Error:" + str(e) yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - # def generate_numbers(): - # for i in range(10): - # time.sleep(0.5) - # yield f"Message:{i}" - # - # def showMessage(message): - # return message - # - # for n in generate_numbers(): - # state.messages[-1][-1] = n + "▌" - # yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - else: - - dbname = db_selector - # TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化 - if state.skip_next: - # This generate call is skipped due to invalid inputs - yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 - return - - if len(state.messages) == state.offset + 2: - query = state.messages[-2][1] - - template_name = "conv_one_shot" - new_state = conv_templates[template_name].copy() - # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? - # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 - if db_selector: - new_state.append_message( - new_state.roles[0], gen_sqlgen_conversation(dbname) + query - ) - new_state.append_message(new_state.roles[1], None) - else: - new_state.append_message(new_state.roles[0], query) - new_state.append_message(new_state.roles[1], None) - - new_state.conv_id = uuid.uuid4().hex - state = new_state - else: - ### 后续对话 - query = state.messages[-2][1] - # 第一轮对话需要加入提示Prompt - if mode == conversation_types["custome"]: - template_name = "conv_one_shot" - new_state = conv_templates[template_name].copy() - # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? - # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 - if db_selector: - new_state.append_message( - new_state.roles[0], gen_sqlgen_conversation(dbname) + query - ) - new_state.append_message(new_state.roles[1], None) - else: - new_state.append_message(new_state.roles[0], query) - new_state.append_message(new_state.roles[1], None) - state = new_state - - prompt = state.get_prompt() - skip_echo_len = len(prompt.replace("", " ")) + 1 - if mode == conversation_types["default_knownledge"] and not db_selector: - vector_store_config = { - "vector_store_name": "default", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, - } - knowledge_embedding_client = KnowledgeEmbedding( - file_path="", - model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, - vector_store_config=vector_store_config, - ) - query = state.messages[-2][1] - docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) - prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state) - state.messages[-2][1] = query - skip_echo_len = len(prompt.replace("", " ")) + 1 - - if mode == conversation_types["custome"] and not db_selector: - print("vector store name: ", vector_store_name["vs_name"]) - vector_store_config = { - "vector_store_name": vector_store_name["vs_name"], - "text_field": "content", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, - } - knowledge_embedding_client = KnowledgeEmbedding( - file_path="", - model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, - vector_store_config=vector_store_config, - ) - query = state.messages[-2][1] - docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) - prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state) - - state.messages[-2][1] = query - skip_echo_len = len(prompt.replace("", " ")) + 1 - - if mode == conversation_types["url"] and url_input: - print("url: ", url_input) - vector_store_config = { - "vector_store_name": url_input, - "text_field": "content", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, - } - knowledge_embedding_client = KnowledgeEmbedding( - file_path=url_input, - model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, - vector_store_config=vector_store_config, - ) - - query = state.messages[-2][1] - docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) - prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state) - - state.messages[-2][1] = query - skip_echo_len = len(prompt.replace("", " ")) + 1 - - # Make requests - payload = { - "model": model_name, - "prompt": prompt, - "temperature": float(temperature), - "max_new_tokens": int(max_new_tokens), - "stop": state.sep - if state.sep_style == SeparatorStyle.SINGLE - else state.sep2, - } - logger.info(f"Requert: \n{payload}") - - # 流式输出 - state.messages[-1][-1] = "▌" - yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 - - try: - # Stream output - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, - json=payload, - stream=True, - timeout=20, - ) - for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: - data = json.loads(chunk.decode()) - - """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. - """ - if data["error_code"] == 0: - if "vicuna" in CFG.LLM_MODEL: - output = data["text"][skip_echo_len:].strip() - else: - output = data["text"].strip() - - output = post_process_code(output) - state.messages[-1][-1] = output + "▌" - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - ) * 5 - else: - output = ( - data["text"] + f" (error_code: {data['error_code']})" - ) - state.messages[-1][-1] = output - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - disable_btn, - disable_btn, - enable_btn, - enable_btn, - ) - return - - except requests.exceptions.RequestException as e: - state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - disable_btn, - disable_btn, - enable_btn, - enable_btn, - ) - return - - state.messages[-1][-1] = state.messages[-1][-1][:-1] - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - - # 记录运行日志 - finish_tstamp = time.time() - logger.info(f"{output}") - - with open(get_conv_log_filename(), "a") as fout: - data = { - "tstamp": round(finish_tstamp, 4), - "type": "chat", - "model": model_name, - "start": round(start_tstamp, 4), - "finish": round(start_tstamp, 4), - "state": state.dict(), - "ip": request.client.host, - } - fout.write(json.dumps(data) + "\n") - + if state.messages[-1][-1].endwith("▌"): + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 block_css = ( code_highlight_css @@ -556,6 +417,7 @@ def build_single_model_ui(): value=dbs[0] if len(models) > 0 else "", interactive=True, show_label=True, + name="db_selector" ).style(container=False) sql_mode = gr.Radio( @@ -565,6 +427,7 @@ def build_single_model_ui(): ], show_label=False, value=get_lang_text("sql_generate_mode_none"), + name="sql_mode" ) sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting")) sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting) @@ -581,7 +444,8 @@ def build_single_model_ui(): value="", interactive=True, show_label=True, - type="value" + type="value", + name="plugin_selector" ).style(container=False) def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData @@ -602,13 +466,14 @@ def build_single_model_ui(): ], show_label=False, value=llm_native_dialogue, + name="mode" ) vs_setting = gr.Accordion( - get_lang_text("configure_knowledge_base"), open=False + get_lang_text("configure_knowledge_base"), open=False, visible=False ) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) - url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True) + url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False, name="url_input") def show_url_input(evt:gr.SelectData): if evt.value == url_knowledge_dialogue: return gr.update(visible=True) @@ -619,7 +484,7 @@ def build_single_model_ui(): with vs_setting: vs_name = gr.Textbox( - label=get_lang_text("new_klg_name"), lines=1, interactive=True + label=get_lang_text("new_klg_name"), lines=1, interactive=True, name = "vs_name" ) vs_add = gr.Button(get_lang_text("add_as_new_klg")) with gr.Column() as doc2vec: @@ -664,10 +529,14 @@ def build_single_model_ui(): clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False) gr.Markdown(learn_more_markdown) + + params = [plugin_selector, mode, sql_mode, db_selector, url_input, vs_name] + + btn_list = [regenerate_btn, clear_btn] regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], + [state, selected, temperature, max_output_tokens] + params, [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) @@ -676,7 +545,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], + [state, selected, temperature, max_output_tokens]+ params, [state, chatbot] + btn_list, ) @@ -684,7 +553,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], + [state, selected, temperature, max_output_tokens]+ params, [state, chatbot] + btn_list, ) vs_add.click( diff --git a/requirements.txt b/requirements.txt index f476a4b23..b2f582eed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ fonttools==4.38.0 frozenlist==1.3.3 huggingface-hub==0.13.4 importlib-resources==5.12.0 + kiwisolver==1.4.4 matplotlib==3.7.0 multidict==6.0.4