From 314920b6e12031f60b9fbf8c15916b83ff48da77 Mon Sep 17 00:00:00 2001
From: "tuyang.yhj"
Date: Tue, 4 Jul 2023 17:19:18 +0800
Subject: [PATCH] WEB API independent

---
 pilot/model/llm_out/vicuna_base_llm.py         | 1 +
 pilot/scene/base_chat.py                       | 2 +-
 pilot/scene/chat_db/auto_execute/out_parser.py | 5 +++--
 pilot/scene/chat_db/auto_execute/prompt.py     | 3 +--
 4 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/pilot/model/llm_out/vicuna_base_llm.py b/pilot/model/llm_out/vicuna_base_llm.py
index 30033860b..d4fcaa33d 100644
--- a/pilot/model/llm_out/vicuna_base_llm.py
+++ b/pilot/model/llm_out/vicuna_base_llm.py
@@ -11,6 +11,7 @@ def generate_stream(
     """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
     prompt = params["prompt"]
     l_prompt = len(prompt)
+    prompt = prompt.replace("ai:", "assistant:").replace("human:", "user:")
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_str = params.get("stop", None)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 6eecbbc71..449df3fe4 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -115,7 +115,7 @@ class BaseChat(ABC):
 
         payload = {
             "model": self.llm_model,
-            "prompt": self.generate_llm_text().replace("ai:", "assistant:"),
+            "prompt": self.generate_llm_text(),
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "stop": self.prompt_template.sep,
diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py
index c7237671b..64432520e 100644
--- a/pilot/scene/chat_db/auto_execute/out_parser.py
+++ b/pilot/scene/chat_db/auto_execute/out_parser.py
@@ -42,8 +42,9 @@ class DbChatOutputParser(BaseOutputParser):
             html_table = df.to_html(index=False, escape=False)
             html = f"{table_style}{html_table}"
         else:
-            html = df.to_html(index=False, escape=False, sparsify=False)
-            html = "".join(html.split())
+            html_table = df.to_html(index=False, escape=False, sparsify=False)
+            table_str = "".join(html_table.split())
+            html = f"""{table_str}"""
 
         view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
         return view_text
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index 1db3d7791..26d85cd61 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -11,7 +11,7 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = None
 
 _DEFAULT_TEMPLATE = """
-You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
+You are a SQL expert. Given an input question, create a syntactically correct {dialect} sql.
 Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
 Use as few tables as possible when querying.
 
@@ -51,7 +51,6 @@ prompt = PromptTemplate(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
     ),
     example_selector=sql_data_example,
-    # example_selector=None,
     temperature=PROMPT_TEMPERATURE
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})