From 5150cfcf55a20d39bdcbb7b4f95031eb0995bbc7 Mon Sep 17 00:00:00 2001
From: yhjun1026 <460342015@qq.com>
Date: Tue, 30 May 2023 17:20:37 +0800
Subject: [PATCH] add plugin mode

---
 pilot/commands/built_in/__init__.py            |   0
 pilot/commands/{ => built_in}/audio_text.py    |   0
 pilot/commands/{ => built_in}/image_gen.py     |   0
 pilot/commands/commands_load.py                |  29 --
 pilot/conversation.py                          |  10 +-
 pilot/language/lang_content_mapping.py         |  11 +-
 pilot/out_parser/base.py                       |  39 ++-
 pilot/scene/base_chat.py                       |  19 +-
 pilot/scene/chat_knowledge/custom/__init__.py  |   0
 pilot/scene/chat_knowledge/default/__init__.py |   0
 pilot/scene/chat_knowledge/url/__init__.py     |   0
 pilot/server/webserver.py                      | 250 ++++++++++--------
 pilot/source_embedding/external/__init__.py    |   0
 pilot/source_embedding/knowledge_embedding.py  |   7 +
 14 files changed, 212 insertions(+), 153 deletions(-)
 create mode 100644 pilot/commands/built_in/__init__.py
 rename pilot/commands/{ => built_in}/audio_text.py (100%)
 rename pilot/commands/{ => built_in}/image_gen.py (100%)
 delete mode 100644 pilot/commands/commands_load.py
 create mode 100644 pilot/scene/chat_knowledge/custom/__init__.py
 create mode 100644 pilot/scene/chat_knowledge/default/__init__.py
 create mode 100644 pilot/scene/chat_knowledge/url/__init__.py
 create mode 100644 pilot/source_embedding/external/__init__.py

diff --git a/pilot/commands/built_in/__init__.py b/pilot/commands/built_in/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/commands/audio_text.py b/pilot/commands/built_in/audio_text.py
similarity index 100%
rename from pilot/commands/audio_text.py
rename to pilot/commands/built_in/audio_text.py
diff --git a/pilot/commands/image_gen.py b/pilot/commands/built_in/image_gen.py
similarity index 100%
rename from pilot/commands/image_gen.py
rename to pilot/commands/built_in/image_gen.py
diff --git a/pilot/commands/commands_load.py b/pilot/commands/commands_load.py
deleted file mode 100644
index a6fad3db2..000000000
--- a/pilot/commands/commands_load.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typing import Optional
-
-from pilot.configs.config import Config
-from pilot.prompts.generator import PromptGenerator
-from pilot.prompts.prompt import build_default_prompt_generator
-
-
-class CommandsLoad:
-    """
-    Load Plugins Commands Info , help build system prompt!
-    """
-
-    def __init__(self) -> None:
-        self.command_registry = None
-
-    def getCommandInfos(
-        self, prompt_generator: Optional[PromptGenerator] = None
-    ) -> str:
-        cfg = Config()
-        if prompt_generator is None:
-            prompt_generator = build_default_prompt_generator()
-        for plugin in cfg.plugins:
-            if not plugin.can_handle_post_prompt():
-                continue
-            prompt_generator = plugin.post_prompt(prompt_generator)
-        self.prompt_generator = prompt_generator
-        command_infos = ""
-        command_infos += f"\n\n{prompt_generator.commands()}"
-        return command_infos
diff --git a/pilot/conversation.py b/pilot/conversation.py
index d2f6565ca..0673b49c3 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -263,6 +263,13 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
 # """
 default_conversation = conv_one_shot
 
+chat_mode_title = {
+    "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
+    "chat_use_plugin": get_lang_text("chat_use_plugin"),
+    "knowledge_qa": get_lang_text("knowledge_qa"),
+}
+
 conversation_sql_mode = {
     "auto_execute_ai_response": get_lang_text("sql_generate_mode_direct"),
     "dont_execute_ai_response": get_lang_text("sql_generate_mode_none"),
@@ -274,7 +281,7 @@ conversation_types = {
         "knowledge_qa_type_default_knowledge_base_dialogue"
     ),
     "custome": get_lang_text("knowledge_qa_type_add_knowledge_base_dialogue"),
-    "auto_execute_plugin": get_lang_text("dialogue_use_plugin"),
+    "url": get_lang_text("knowledge_qa_type_url_knowledge_dialogue"),
 }
 
 conv_templates = {
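
A note on the `chat_mode_title` table added above: the UI tabs and the request router both key off the *localized* title, so they must read from the same mapping or routing silently breaks when the language changes. The sketch below shows only that lookup pattern in isolation; `ChatScene`, `get_lang_text` and the label values here are stand-ins, not the project's real definitions.

```python
# Minimal sketch of localized-title -> scene routing. Everything here is a
# stand-in for the real pilot.language / pilot.scene imports; only the
# "compare against the shared mapping, not a hard-coded literal" idea matters.
from enum import Enum


class ChatScene(Enum):
    ChatExecution = "chat_execution"
    ChatKnowledge = "chat_knowledge"
    ChatNormal = "chat_normal"


LANG = {"chat_use_plugin": "插件模式", "knowledge_qa": "知识问答"}


def get_lang_text(key: str) -> str:
    return LANG[key]


# Same shape as the patch's chat_mode_title: stable keys -> localized labels.
chat_mode_title = {
    "chat_use_plugin": get_lang_text("chat_use_plugin"),
    "knowledge_qa": get_lang_text("knowledge_qa"),
}


def get_chat_mode(selected: str) -> ChatScene:
    # Comparing against the mapping keeps routing correct in any UI language.
    if selected == chat_mode_title["chat_use_plugin"]:
        return ChatScene.ChatExecution
    if selected == chat_mode_title["knowledge_qa"]:
        return ChatScene.ChatKnowledge
    return ChatScene.ChatNormal


assert get_chat_mode("插件模式") is ChatScene.ChatExecution
```
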
diff --git a/pilot/language/lang_content_mapping.py b/pilot/language/lang_content_mapping.py
index 5d165b51c..bcea7ed3c 100644
--- a/pilot/language/lang_content_mapping.py
+++ b/pilot/language/lang_content_mapping.py
@@ -14,17 +14,22 @@ lang_dicts = {
         "knowledge_qa_type_llm_native_dialogue": "LLM原生对话",
         "knowledge_qa_type_default_knowledge_base_dialogue": "默认知识库对话",
         "knowledge_qa_type_add_knowledge_base_dialogue": "新增知识库对话",
-        "dialogue_use_plugin": "对话使用插件",
+        "knowledge_qa_type_url_knowledge_dialogue": "URL网页知识对话",
         "create_knowledge_base": "新建知识库",
         "sql_schema_info": "数据库{}的Schema信息如下: {}\n",
         "current_dialogue_mode": "当前对话模式",
         "database_smart_assistant": "数据库智能助手",
         "sql_vs_setting": "自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力",
         "knowledge_qa": "知识问答",
+        "chat_use_plugin": "插件模式",
+        "dialogue_use_plugin": "对话使用插件",
+        "select_plugin": "选择插件",
         "configure_knowledge_base": "配置知识库",
         "new_klg_name": "新知识库名称",
+        "url_input_label": "输入网页地址",
         "add_as_new_klg": "添加为新知识库",
         "add_file_to_klg": "向知识库中添加文件",
+        "upload_file": "上传文件",
         "add_file": "添加文件",
         "upload_and_load_to_klg": "上传并加载到知识库",
@@ -47,14 +52,18 @@ lang_dicts = {
         "knowledge_qa_type_llm_native_dialogue": "LLM native dialogue",
         "knowledge_qa_type_default_knowledge_base_dialogue": "Default documents",
         "knowledge_qa_type_add_knowledge_base_dialogue": "Added documents",
+        "knowledge_qa_type_url_knowledge_dialogue": "Chat with url",
         "dialogue_use_plugin": "Dialogue Extension",
         "create_knowledge_base": "Create Knowledge Base",
         "sql_schema_info": "the schema information of database {}: {}\n",
         "current_dialogue_mode": "Current dialogue mode",
         "database_smart_assistant": "Database smart assistant",
         "sql_vs_setting": "In the automatic execution mode, DB-GPT can have the ability to execute SQL, read data from the network, automatically store and learn",
+        "chat_use_plugin": "Plugin Mode",
+        "select_plugin": "Select Plugin",
         "knowledge_qa": "Documents QA",
         "configure_knowledge_base": "Configure Documents",
+        "url_input_label": "Please input url",
         "new_klg_name": "New document name",
         "add_as_new_klg": "Add as new documents",
         "add_file_to_klg": "Add file to documents",
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 36ca8eb9c..57f0a7f7e 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -18,11 +18,14 @@ import re
 from pydantic import BaseModel, Extra, Field, root_validator
 
 from pilot.configs.model_config import LOGDIR
-from pilot.prompts.base import PromptValue
+from pilot.configs.config import Config
 
 T = TypeVar("T")
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
+CFG = Config()
+
+
 class BaseOutputParser(ABC):
     """Class to parse the output of an LLM call.
 
@@ -33,9 +36,39 @@ class BaseOutputParser(ABC):
         self.sep = sep
         self.is_stream_out = is_stream_out
 
+    def __post_process_code(self, code):
+        sep = "\n```"
+        if sep in code:
+            blocks = code.split(sep)
+            if len(blocks) % 2 == 1:
+                for i in range(1, len(blocks), 2):
+                    blocks[i] = blocks[i].replace("\\_", "_")
+            code = sep.join(blocks)
+        return code
+
     # TODO: bind this to the model later
     def _parse_model_stream_resp(self, response, sep: str):
-        pass
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                """TODO: multi-model output handler; rewrite this with an
+                adapter pattern to support multiple models.
+                """
+                if data["error_code"] == 0:
+                    # TODO: model-specific handling goes here (e.g. stripping
+                    # the prompt echo for vicuna-style models).
+                    output = data["text"].strip()
+                    output = self.__post_process_code(output)
+                    yield output
+                else:
+                    output = data["text"] + f" (error_code: {data['error_code']})"
+                    yield output
 
     def _parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
@@ -64,7 +97,7 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
-    def parse_model_server_out(self, response) -> str:
+    def parse_model_server_out(self, response):
         """
         parse the model server http response
         Args:
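
The new `_parse_model_stream_resp` above relies on a simple wire contract: the model server writes one JSON object per chunk, with chunks separated by NUL bytes (`b"\0"`), and each successful chunk carries the full text generated so far. A minimal standalone consumer of that contract might look like the following sketch; the URL and payload are placeholders, not DB-GPT's actual configuration.

```python
# Sketch of a client for a NUL-delimited JSON streaming endpoint, assuming
# the same chunk format the parser above expects. Placeholder URL/payload.
import json

import requests


def stream_model_output(server_url: str, payload: dict):
    response = requests.post(server_url, json=payload, stream=True, timeout=120)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # Each chunk is the cumulative text so far, not a delta.
            yield data["text"].strip()
        else:
            yield data["text"] + f" (error_code: {data['error_code']})"
            return


# for text in stream_model_output("http://localhost:8000/generate_stream",
#                                 {"prompt": "hello", "temperature": 0.7}):
#     print(text)
```
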
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 7a1c77781..a376759bc 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 import datetime
 import traceback
+import json
 from pydantic import BaseModel, Field, root_validator, validator, Extra
 from typing import (
     Any,
@@ -41,6 +42,7 @@ headers = {"User-Agent": "dbgpt Client"}
 
 CFG = Config()
 
+
 class BaseChat(ABC):
     chat_scene: str = None
     llm_model: Any = None
@@ -89,8 +91,7 @@ class BaseChat(ABC):
     def do_with_prompt_response(self, prompt_response):
         pass
 
-
-    def call(self):
+    def call(self, show_fn, state):
         input_values = self.generate_input_values()
         ### Chat sequence advance
@@ -164,6 +165,7 @@
                     prompt_define_response, result
                 )
             )
+            show_fn(state, self.current_ai_response())
         else:
             response = requests.post(
                 urljoin(CFG.MODEL_SERVER, "generate_stream"),
@@ -171,9 +173,14 @@
                 json=payload,
                 timeout=120,
             )
-            #TODO
-
+            show_fn(state, "▌")
+            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+            show_info = ""
+            for resp_text_chunk in ai_response_text:
+                show_info = resp_text_chunk
+                show_fn(state, resp_text_chunk + "▌")
+
+            self.current_message.add_ai_message(show_info)
         except Exception as e:
             print(traceback.format_exc())
@@ -181,9 +188,11 @@
             self.current_message.add_view_message(
                 f"""ERROR!{str(e)}\n  {ai_response_text} """
             )
+            show_fn(state, self.current_ai_response())
         ### Store the conversation record
         self.memory.append(self.current_message)
 
+
     def generate_llm_text(self) -> str:
         text = self.prompt_template.template_define + self.prompt_template.sep
         ### Process the history messages first
@@ -229,8 +238,6 @@
 
         return text
 
-    def chat_show(self):
-        pass
 
     # kept for now for front-end compatibility
     def current_ai_response(self) -> str:
diff --git a/pilot/scene/chat_knowledge/custom/__init__.py b/pilot/scene/chat_knowledge/custom/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/default/__init__.py b/pilot/scene/chat_knowledge/default/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/url/__init__.py b/pilot/scene/chat_knowledge/url/__init__.py
new file mode 100644
index 000000000..e69de29bb
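
The signature change from `call(self)` to `call(self, show_fn, state)` inverts control: the chat object no longer hands back a finished string, it pushes every partial render through a caller-supplied UI callback. A toy illustration of that contract, with `FakeChat` standing in for a `BaseChat` subclass:

```python
# Sketch of the call(show_fn, state) contract. FakeChat is a stand-in for a
# BaseChat subclass; the only assumption is that show_fn(state, text) gets
# invoked once per successive render of the answer.
class FakeChat:
    def __init__(self, chunks):
        self.chunks = chunks

    def call(self, show_fn, state):
        show_fn(state, "▌")  # initial cursor, as in BaseChat.call
        shown = ""
        for chunk in self.chunks:
            shown = chunk  # each chunk is the cumulative text so far
            show_fn(state, chunk + "▌")
        show_fn(state, shown)  # final render without the cursor


def print_callback(state, message):
    print(f"[{state}] {message}")


FakeChat(["Hel", "Hello", "Hello!"]).call(print_callback, state="conv-1")
```

In the webserver the callback writes into Gradio's chat state instead of printing; the contract is only that `show_fn(state, text)` is called with each successive render.
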
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 9d04280d6..2f77277aa 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -37,6 +37,7 @@ from pilot.conversation import (
     conv_templates,
     conversation_sql_mode,
     conversation_types,
+    chat_mode_title,
     default_conversation,
 )
 from pilot.common.plugins import scan_plugins
@@ -95,6 +96,11 @@
 add_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_add_knowledge_base_dialogue"
 )
+
+url_knowledge_dialogue = get_lang_text(
+    "knowledge_qa_type_url_knowledge_dialogue"
+)
+
 knowledge_qa_type_list = [
     llm_native_dialogue,
     default_knowledge_base_dialogue,
@@ -115,7 +121,7 @@ def gen_sqlgen_conversation(dbname):
     db_connect = CFG.local_db.get_session(dbname)
     schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
-        message += s["schema_info"] + ";"
+        message += s + ";"
     return get_lang_text("sql_schema_info").format(dbname, message)
@@ -211,9 +217,9 @@ def post_process_code(code):
 
 
 def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
-    if "插件模式" == selected:
+    if chat_mode_title['chat_use_plugin'] == selected:
         return ChatScene.ChatExecution
-    elif "知识问答" == selected:
+    elif chat_mode_title['knowledge_qa'] == selected:
         if mode == conversation_types["default_knownledge"]:
             return ChatScene.ChatKnowledge
         elif mode == conversation_types["custome"]:
@@ -226,37 +232,50 @@
 def http_bot(
-    state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+    state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request
 ):
     logger.info(f"User message sent!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
     start_tstamp = time.time()
-    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
-    print(f"当前对话模式:{scene.value}")
+    scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
+    print(f"now chat scene: {scene.value}")
     model_name = CFG.LLM_MODEL
 
+    def chatbot_callback(state, message):
+        print(f"chatbot_callback: {message}")
+        state.messages[-1][-1] = f"{message}"
+        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
     if ChatScene.ChatWithDb == scene:
-        logger.info("基于DB对话走新的模式!")
+        logger.info("chat with db mode uses the new architecture design!")
         chat_param = {
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
             "user_input": state.last_user_input,
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
-        chat.call()
-        state.messages[-1][-1] = f"{chat.current_ai_response()}"
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        chat.call(show_fn=chatbot_callback, state=state)
+
     elif ChatScene.ChatExecution == scene:
-        logger.info("插件模式对话走新的模式!")
+        logger.info("plugin mode uses the new architecture design!")
         chat_param = {
             "chat_session_id": state.conv_id,
             "plugin_selector": plugin_selector,
             "user_input": state.last_user_input,
        }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
-        chat.call()
-        state.messages[-1][-1] = f"{chat.current_ai_response()}"
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        chat.call(chatbot_callback, state)
     else:
         dbname = db_selector
@@ -284,30 +303,45 @@
         new_state.conv_id = uuid.uuid4().hex
 
         state = new_state
+    else:
+        ### Follow-up turns of the dialogue
+        query = state.messages[-2][1]
+        # the guidance prompt needs to be added on the first turn
+        if mode == conversation_types["custome"]:
+            template_name = "conv_one_shot"
+            new_state = conv_templates[template_name].copy()
+            # A context hint is added to the prompt when chatting over existing
+            # knowledge. Should it go only in the first turn, or in every turn?
+            # If the user's questions jump between topics, it should be added
+            # on every turn.
+            if db_selector:
+                new_state.append_message(
+                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                )
+                new_state.append_message(new_state.roles[1], None)
+            else:
+                new_state.append_message(new_state.roles[0], query)
+                new_state.append_message(new_state.roles[1], None)
+            state = new_state
 
     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:
+        vector_store_config = {
+            "vector_store_name": "default",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
         query = state.messages[-2][1]
-        knqa = KnownLedgeBaseQA()
-        state.messages[-2][1] = knqa.get_similar_answer(query)
-        prompt = state.get_prompt()
+        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
     if mode == conversation_types["custome"] and not db_selector:
-        persist_dir = os.path.join(
-            KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
-        )
-        print("向量数据库持久化地址: ", persist_dir)
-        knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
-            model_name=LLM_MODEL_CONFIG["sentence-transforms"],
-            vector_store_config={
-                "vector_store_name": vector_store_name["vs_name"],
-                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-            },
-        )
         print("vector store name: ", vector_store_name["vs_name"])
         vector_store_config = {
             "vector_store_name": vector_store_name["vs_name"],
@@ -327,6 +361,27 @@
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
+    if mode == conversation_types["url"] and url_input:
+        print("url: ", url_input)
+        vector_store_config = {
+            "vector_store_name": url_input,
+            "text_field": "content",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path=url_input,
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
+
+        query = state.messages[-2][1]
+        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
+
+        state.messages[-2][1] = query
+        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+
     # Make requests
     payload = {
         "model": model_name,
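
Pulling the new URL branch out of the Gradio handler, the flow is: build a per-URL vector store, embed the page under that name, then retrieve the top-K chunks for the user's question and fold them into the prompt. The sketch below mirrors the handler's calls exactly (same keyword arguments as the hunk above); the path, model name and URL values are illustrative stand-ins, and it assumes it runs inside the DB-GPT repo where `pilot` is importable.

```python
# Usage sketch of the URL knowledge mode, mirroring the handler's calls.
# The constants below are stand-ins for KNOWLEDGE_UPLOAD_ROOT_PATH,
# LLM_MODEL_CONFIG["text2vec"] and VECTOR_SEARCH_TOP_K.
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

KNOWLEDGE_UPLOAD_ROOT_PATH = "/tmp/dbgpt/vectors"  # stand-in path
VECTOR_SEARCH_TOP_K = 10

url = "https://example.com/some-article"
client = KnowledgeEmbedding(
    file_path=url,  # a URL is treated as the document source
    model_name="text2vec",  # stand-in for LLM_MODEL_CONFIG["text2vec"]
    local_persist=False,
    vector_store_config={
        "vector_store_name": url,  # one vector store per URL
        "text_field": "content",
        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    },
)

docs = client.similar_search("What is the article about?", VECTOR_SEARCH_TOP_K)
```
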
@@ -355,13 +410,24 @@
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
+
+                """TODO: multi-model output handler; rewrite this with an
+                adapter pattern to support multiple models.
+                """
                 if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len:].strip()
+                    if "vicuna" in CFG.LLM_MODEL:
+                        output = data["text"][skip_echo_len:].strip()
+                    else:
+                        output = data["text"].strip()
+
                     output = post_process_code(output)
                     state.messages[-1][-1] = output + "▌"
-                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+                    yield (state, state.to_gradio_chatbot()) + (
+                        disable_btn,
+                    ) * 5
                 else:
-                    output = data["text"] + f" (error_code: {data['error_code']})"
+                    output = (
+                        data["text"] + f" (error_code: {data['error_code']})"
+                    )
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (
                         disable_btn,
@@ -371,56 +437,7 @@
                         enable_btn,
                     )
                     return
-        try:
-            # Stream output
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=20,
-            )
-            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-                if chunk:
-                    data = json.loads(chunk.decode())
-                    """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
-                    """
-                    if data["error_code"] == 0:
-                        if "vicuna" in CFG.LLM_MODEL:
-                            output = data["text"][skip_echo_len:].strip()
-                        else:
-                            output = data["text"].strip()
-
-                        output = post_process_code(output)
-                        state.messages[-1][-1] = output + "▌"
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                        ) * 5
-                    else:
-                        output = (
-                            data["text"] + f" (error_code: {data['error_code']})"
-                        )
-                        state.messages[-1][-1] = output
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                            disable_btn,
-                            disable_btn,
-                            enable_btn,
-                            enable_btn,
-                        )
-                        return
-
-        except requests.exceptions.RequestException as e:
-            state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-            yield (state, state.to_gradio_chatbot()) + (
-                disable_btn,
-                disable_btn,
-                disable_btn,
-                enable_btn,
-                enable_btn,
-            )
-            return
-
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
         yield (state, state.to_gradio_chatbot()) + (
@@ -432,29 +449,29 @@
         )
         return
 
-    state.messages[-1][-1] = state.messages[-1][-1][:-1]
-    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        state.messages[-1][-1] = state.messages[-1][-1][:-1]
+        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
-    # 记录运行日志
-    finish_tstamp = time.time()
-    logger.info(f"{output}")
+        # log this run
+        finish_tstamp = time.time()
+        logger.info(f"{output}")
 
-    with open(get_conv_log_filename(), "a") as fout:
-        data = {
-            "tstamp": round(finish_tstamp, 4),
-            "type": "chat",
-            "model": model_name,
-            "start": round(start_tstamp, 4),
-            "finish": round(start_tstamp, 4),
-            "state": state.dict(),
-            "ip": request.client.host,
-        }
-        fout.write(json.dumps(data) + "\n")
+        with open(get_conv_log_filename(), "a") as fout:
+            data = {
+                "tstamp": round(finish_tstamp, 4),
+                "type": "chat",
+                "model": model_name,
+                "start": round(start_tstamp, 4),
+                "finish": round(finish_tstamp, 4),
+                "state": state.dict(),
+                "ip": request.client.host,
+            }
+            fout.write(json.dumps(data) + "\n")
 
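
The TODO in the stream loop above asks for per-model output handling "in adapter mode" instead of inline `if "vicuna" in CFG.LLM_MODEL` checks. One possible shape for that adapter registry is sketched below; this is an assumption about the intended design, not code from this patch. The only grounded facts are the two behaviors the loop already shows: vicuna-style servers echo the prompt (hence `skip_echo_len`), other servers return just the completion.

```python
# Hypothetical adapter registry for model-specific output handling.
from typing import Callable, Dict

OutputAdapter = Callable[[dict, int], str]

_ADAPTERS: Dict[str, OutputAdapter] = {}


def register_adapter(model_keyword: str):
    def wrap(fn: OutputAdapter) -> OutputAdapter:
        _ADAPTERS[model_keyword] = fn
        return fn
    return wrap


@register_adapter("vicuna")
def vicuna_output(data: dict, skip_echo_len: int) -> str:
    # vicuna responses include the prompt, so strip the echoed prefix.
    return data["text"][skip_echo_len:].strip()


def default_output(data: dict, skip_echo_len: int) -> str:
    return data["text"].strip()


def parse_output(model_name: str, data: dict, skip_echo_len: int) -> str:
    # Pick the first adapter whose keyword appears in the model name.
    for keyword, adapter in _ADAPTERS.items():
        if keyword in model_name:
            return adapter(data, skip_echo_len)
    return default_output(data, skip_echo_len)


print(parse_output("vicuna-13b", {"text": "PROMPT: answer"}, 8))  # "answer"
```
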
 block_css = (
-        code_highlight_css
-        + """
+    code_highlight_css
+    + """
 pre {
     white-space: pre-wrap;       /* Since CSS 2.1 */
     white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
@@ -477,15 +494,12 @@ def change_sql_mode(sql_mode):
 
 
 def change_mode(mode):
-    if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
-        return gr.update(visible=False)
-    else:
+    if mode in [add_knowledge_base_dialogue]:
         return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
 
 
-def change_tab():
-    autogpt = True
-
 def build_single_model_ui():
     notice_markdown = get_lang_text("db_gpt_introduction")
@@ -548,15 +562,14 @@
             sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
             sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
 
         tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
-        tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
+        tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
         # tab_plugin.select(change_func)
         with tab_plugin:
             print("tab_plugin in...")
             with gr.Row(elem_id="plugin_selector"):
                 # TODO
                 plugin_selector = gr.Dropdown(
-                    label="请选择插件",
+                    label=get_lang_text("select_plugin"),
                     choices=list(plugins_select_info().keys()),
                     value="",
                     interactive=True,
@@ -578,6 +591,7 @@
                     llm_native_dialogue,
                     default_knowledge_base_dialogue,
                     add_knowledge_base_dialogue,
+                    url_knowledge_dialogue,
                 ],
                 show_label=False,
                 value=llm_native_dialogue,
@@ -586,6 +600,16 @@
                 get_lang_text("configure_knowledge_base"), open=False
             )
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+
+            url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True)
+
+            def show_url_input(evt: gr.SelectData):
+                if evt.value == url_knowledge_dialogue:
+                    return gr.update(visible=True)
+                else:
+                    return gr.update(visible=False)
+
+            mode.select(fn=show_url_input, inputs=None, outputs=url_input)
+
             with vs_setting:
                 vs_name = gr.Textbox(
                     label=get_lang_text("new_klg_name"), lines=1, interactive=True
@@ -636,7 +660,7 @@
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -645,7 +669,7 @@
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
 
@@ -653,7 +677,7 @@
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     vs_add.click(
@@ -760,8 +784,8 @@ if __name__ == "__main__":
 
     # load the executable commands from plugins
     command_categories = [
-        "pilot.commands.audio_text",
-        "pilot.commands.image_gen",
+        "pilot.commands.built_in.audio_text",
+        "pilot.commands.built_in.image_gen",
     ]
     # exclude disabled commands
     command_categories = [
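
The `show_url_input` handler added above is the standard Gradio pattern for value-dependent visibility: a `Radio.select` event inspects the chosen value and returns a `gr.update` for the dependent component. A self-contained sketch of just that pattern, assuming a Gradio 3.x install where `gr.SelectData` carries the selected radio value (the labels here are the English locale strings from this patch):

```python
# Minimal sketch: reveal a Textbox only when a particular radio value is
# selected, mirroring the mode.select wiring in build_single_model_ui.
import gradio as gr

URL_MODE = "Chat with url"

with gr.Blocks() as demo:
    mode = gr.Radio(
        ["LLM native dialogue", "Default documents", URL_MODE],
        value="LLM native dialogue",
        show_label=False,
    )
    url_input = gr.Textbox(label="Please input url", visible=False)

    def show_url_input(evt: gr.SelectData):
        # Only reveal the URL box when the URL-knowledge mode is picked.
        return gr.update(visible=(evt.value == URL_MODE))

    mode.select(fn=show_url_input, inputs=None, outputs=url_input)

# demo.launch()
```
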
diff --git a/pilot/source_embedding/external/__init__.py b/pilot/source_embedding/external/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py
index 316667dee..8f411657d 100644
--- a/pilot/source_embedding/knowledge_embedding.py
+++ b/pilot/source_embedding/knowledge_embedding.py
@@ -11,6 +11,7 @@ from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 from pilot.source_embedding.csv_embedding import CSVEmbedding
 from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
 from pilot.source_embedding.pdf_embedding import PDFEmbedding
+from pilot.source_embedding.url_embedding import URLEmbedding
 from pilot.vector_store.connector import VectorStoreConnector
 
 CFG = Config()
@@ -61,6 +62,12 @@ class KnowledgeEmbedding:
                 model_name=self.model_name,
                 vector_store_config=self.vector_store_config,
             )
+        elif self.file_type == "url":
+            embedding = URLEmbedding(
+                file_path=self.file_path,
+                model_name=self.model_name,
+                vector_store_config=self.vector_store_config,
+            )
 
         return embedding
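
The hunk above extends `KnowledgeEmbedding`'s file-type dispatch with one more `elif`. If the chain keeps growing, a lookup table gives the same behavior in one place. The sketch below is an alternative shape, not the project's code; the embedding classes are stand-ins for the real `CSVEmbedding` / `MarkdownEmbedding` / `PDFEmbedding` / `URLEmbedding` imports.

```python
# Table-driven version of the file_type -> embedding-class dispatch.
class _BaseEmbedding:
    def __init__(self, file_path, model_name, vector_store_config):
        self.file_path = file_path
        self.model_name = model_name
        self.vector_store_config = vector_store_config


class CSVEmbedding(_BaseEmbedding): ...
class MarkdownEmbedding(_BaseEmbedding): ...
class PDFEmbedding(_BaseEmbedding): ...
class URLEmbedding(_BaseEmbedding): ...


EMBEDDING_BY_TYPE = {
    "csv": CSVEmbedding,
    "md": MarkdownEmbedding,
    "pdf": PDFEmbedding,
    "url": URLEmbedding,
}


def make_embedding(file_type, file_path, model_name, vector_store_config):
    try:
        cls = EMBEDDING_BY_TYPE[file_type]
    except KeyError:
        raise ValueError(f"unsupported knowledge source type: {file_type}")
    return cls(file_path, model_name, vector_store_config)


emb = make_embedding("url", "https://example.com", "text2vec", {})
assert isinstance(emb, URLEmbedding)
```

Registering a new source type then becomes one dictionary entry plus its class, rather than another branch in `init_knowledge_embedding`.
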