From 5f9c36a0506134d0132232c07091bb73438f3d42 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Tue, 4 Jul 2023 16:41:02 +0800 Subject: [PATCH 1/3] WEB API independent --- pilot/connections/rdbms/py_study/study_data.py | 13 ++++++++++--- pilot/model/llm_out/vicuna_base_llm.py | 1 - pilot/out_parser/base.py | 18 +++++++++++++++--- pilot/prompts/prompt_new.py | 3 +++ pilot/scene/base_chat.py | 9 ++++----- .../template/sales_report/dashboard.json | 2 +- pilot/scene/chat_db/auto_execute/example.py | 10 +++++----- pilot/scene/chat_db/auto_execute/prompt.py | 8 +++++++- pilot/scene/chat_execution/example.py | 10 +++++----- 9 files changed, 50 insertions(+), 24 deletions(-) diff --git a/pilot/connections/rdbms/py_study/study_data.py b/pilot/connections/rdbms/py_study/study_data.py index c83c52acc..6d6b88e04 100644 --- a/pilot/connections/rdbms/py_study/study_data.py +++ b/pilot/connections/rdbms/py_study/study_data.py @@ -1,10 +1,17 @@ +import json from pilot.common.sql_database import Database from pilot.configs.config import Config CFG = Config() if __name__ == "__main__": - connect = CFG.local_db.get_session("gpt-user") - datas = CFG.local_db.run(connect, "SELECT * FROM users; ") + # connect = CFG.local_db.get_session("gpt-user") + # datas = CFG.local_db.run(connect, "SELECT * FROM users; ") + + # print(datas) + + str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}""" + + print(str.find("[")) + - print(datas) diff --git a/pilot/model/llm_out/vicuna_base_llm.py b/pilot/model/llm_out/vicuna_base_llm.py index 042f9954b..30033860b 100644 --- a/pilot/model/llm_out/vicuna_base_llm.py +++ b/pilot/model/llm_out/vicuna_base_llm.py @@ -14,7 +14,6 @@ def generate_stream( temperature = float(params.get("temperature", 1.0)) max_new_tokens = int(params.get("max_new_tokens", 2048)) stop_str = params.get("stop", None) - input_ids = tokenizer(prompt).input_ids output_ids = list(input_ids) diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 3631d12ee..cec51a9c5 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -113,25 +113,36 @@ class BaseOutputParser(ABC): ai_response = ai_response.replace("\n", " ") ai_response = ai_response.replace("\_", "_") ai_response = ai_response.replace("\*", "*") + ai_response = ai_response.replace("\t", "") print("un_stream ai response:", ai_response) return ai_response else: raise ValueError("Model server error!code=" + resp_obj_ex["error_code"]) + def __illegal_json_ends(self, s): + temp_json = s + illegal_json_ends_1 = [", }", ",}"] + illegal_json_ends_2 = ", ]", ",]" + for illegal_json_end in illegal_json_ends_1: + temp_json = temp_json.replace(illegal_json_end, " }") + for illegal_json_end in illegal_json_ends_2: + temp_json = temp_json.replace(illegal_json_end, " ]") + return temp_json + def __extract_json(self, s): temp_json = self.__json_interception(s, True) if not temp_json: temp_json = self.__json_interception(s) try: - json.loads(temp_json) + temp_json = self.__illegal_json_ends(temp_json) return temp_json except Exception as e: raise ValueError("Failed to find a valid json response!" + temp_json) def __json_interception(self, s, is_json_array: bool = False): if is_json_array: - i = s.index("[") + i = s.find("[") if i <0: return None count = 1 @@ -145,7 +156,7 @@ class BaseOutputParser(ABC): assert count == 0 return s[i: j + 1] else: - i = s.index("{") + i = s.find("{") if i <0: return None count = 1 @@ -189,6 +200,7 @@ class BaseOutputParser(ABC): .replace("\\n", " ") .replace("\\", " ") ) + cleaned_output = self.__illegal_json_ends(cleaned_output) return cleaned_output def parse_view_response(self, ai_text, data) -> str: diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 80f05c730..78e1585ea 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -51,6 +51,9 @@ class PromptTemplate(BaseModel, ABC): need_historical_messages: bool = False + temperature: float = 0.6 + max_new_tokens: int = 1024 + class Config: """Configuration for this pydantic object.""" diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 33553cb30..6eecbbc71 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -48,8 +48,6 @@ CFG = Config() class BaseChat(ABC): chat_scene: str = None llm_model: Any = None - temperature: float = 0.6 - max_new_tokens: int = 1024 # By default, keep the last two rounds of conversation records as the context chat_retention_rounds: int = 1 @@ -117,9 +115,9 @@ class BaseChat(ABC): payload = { "model": self.llm_model, - "prompt": self.generate_llm_text(), - "temperature": float(self.temperature), - "max_new_tokens": int(self.max_new_tokens), + "prompt": self.generate_llm_text().replace("ai:", "assistant:"), + "temperature": float(self.prompt_template.temperature), + "max_new_tokens": int(self.prompt_template.max_new_tokens), "stop": self.prompt_template.sep, } return payload @@ -128,6 +126,7 @@ class BaseChat(ABC): # TODO Retry when server connection error payload = self.__call_base() + self.skip_echo_len = len(payload.get("prompt").replace("", " ")) + 11 logger.info(f"Requert: \n{payload}") ai_response_text = "" diff --git a/pilot/scene/chat_dashboard/template/sales_report/dashboard.json b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json index 5a7f836e9..7301c5669 100644 --- a/pilot/scene/chat_dashboard/template/sales_report/dashboard.json +++ b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json @@ -3,7 +3,7 @@ "name": "sale_report", "introduce": "", "layout": "TODO", - "supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart", "Scatterplot", "IndicatorValue", "Table"], + "supported_chart_type":["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "IndicatorValue"], "key_metrics":[], "trends": [] } \ No newline at end of file diff --git a/pilot/scene/chat_db/auto_execute/example.py b/pilot/scene/chat_db/auto_execute/example.py index 73fea6f51..9f877fdad 100644 --- a/pilot/scene/chat_db/auto_execute/example.py +++ b/pilot/scene/chat_db/auto_execute/example.py @@ -14,8 +14,8 @@ EXAMPLES = [ \"sql\": \"SELECT city FROM users where user_name='test1'\", }""", "example": True, - }, - }, + } + } ] }, { @@ -29,10 +29,10 @@ EXAMPLES = [ \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\", }""", "example": True, - }, - }, + } + } ] - }, + } ] sql_data_example = ExampleSelector( diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 5b7174638..1db3d7791 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -10,7 +10,6 @@ CFG = Config() PROMPT_SCENE_DEFINE = None - _DEFAULT_TEMPLATE = """ You are a SQL expert. Given an input question, create a syntactically correct {dialect} query. @@ -36,6 +35,11 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = False +# Temperature is a configuration hyperparameter that controls the randomness of language model output. +# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output. +# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0. +PROMPT_TEMPERATURE = 0.5 + prompt = PromptTemplate( template_scene=ChatScene.ChatWithDbExecute.value(), input_variables=["input", "table_info", "dialect", "top_k", "response"], @@ -47,5 +51,7 @@ prompt = PromptTemplate( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT ), example_selector=sql_data_example, + # example_selector=None, + temperature=PROMPT_TEMPERATURE ) CFG.prompt_templates.update({prompt.template_scene: prompt}) diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py index 35f0c634a..da33e5378 100644 --- a/pilot/scene/chat_execution/example.py +++ b/pilot/scene/chat_execution/example.py @@ -14,8 +14,8 @@ EXAMPLES = [ \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}}, }""", "example": True, - }, - }, + } + } ] }, { @@ -30,10 +30,10 @@ EXAMPLES = [ \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}}, }""", "example": True, - }, - }, + } + } ] - }, + } ] plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True) From e8c61c29e2393bb84f7c7a28a618020bcd7dd4e9 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Tue, 4 Jul 2023 16:50:49 +0800 Subject: [PATCH 2/3] WEB API independent --- pilot/configs/config.py | 1 + .../scene/chat_dashboard/template/sales_report/dashboard.json | 2 +- pilot/scene/chat_db/auto_execute/example.py | 4 ++-- pilot/scene/chat_db/auto_execute/out_parser.py | 2 +- pilot/server/dbgpt_server.py | 3 ++- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 0be6f18fc..804176357 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -18,6 +18,7 @@ class Config(metaclass=Singleton): """Initialize the Config class""" self.NEW_SERVER_MODE = False + self.SERVER_LIGHT_MODE = False # Gradio language version: en, zh self.LANGUAGE = os.getenv("LANGUAGE", "en") diff --git a/pilot/scene/chat_dashboard/template/sales_report/dashboard.json b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json index 7301c5669..ab1e7abdf 100644 --- a/pilot/scene/chat_dashboard/template/sales_report/dashboard.json +++ b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json @@ -3,7 +3,7 @@ "name": "sale_report", "introduce": "", "layout": "TODO", - "supported_chart_type":["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "IndicatorValue"], + "supported_chart_type":["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "KeyMetrics"], "key_metrics":[], "trends": [] } \ No newline at end of file diff --git a/pilot/scene/chat_db/auto_execute/example.py b/pilot/scene/chat_db/auto_execute/example.py index 9f877fdad..1e9e3e15f 100644 --- a/pilot/scene/chat_db/auto_execute/example.py +++ b/pilot/scene/chat_db/auto_execute/example.py @@ -11,7 +11,7 @@ EXAMPLES = [ "data": { "content": """{ \"thoughts\": \"thought text\", - \"sql\": \"SELECT city FROM users where user_name='test1'\", + \"sql\": \"SELECT city FROM user where user_name='test1'\", }""", "example": True, } @@ -26,7 +26,7 @@ EXAMPLES = [ "data": { "content": """{ \"thoughts\": \"thought text\", - \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\", + \"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\", }""", "example": True, } diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index 9b02d1ba1..c7237671b 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -35,7 +35,7 @@ class DbChatOutputParser(BaseOutputParser): if len(data) <= 1: data.insert(0, ["result"]) df = pd.DataFrame(data[1:], columns=data[0]) - if not CFG.NEW_SERVER_MODE: + if not CFG.NEW_SERVER_MODE and not CFG.SERVER_LIGHT_MODE: table_style = """""" diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 0279a2b86..032b2b896 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -103,7 +103,8 @@ if __name__ == "__main__": from pilot.server.llmserver import worker worker.start_check() CFG.NEW_SERVER_MODE = True - + else: + CFG.SERVER_LIGHT_MODE = True import uvicorn uvicorn.run(app, host="0.0.0.0", port=args.port) From 314920b6e12031f60b9fbf8c15916b83ff48da77 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Tue, 4 Jul 2023 17:19:18 +0800 Subject: [PATCH 3/3] WEB API independent --- pilot/model/llm_out/vicuna_base_llm.py | 1 + pilot/scene/base_chat.py | 2 +- pilot/scene/chat_db/auto_execute/out_parser.py | 5 +++-- pilot/scene/chat_db/auto_execute/prompt.py | 3 +-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pilot/model/llm_out/vicuna_base_llm.py b/pilot/model/llm_out/vicuna_base_llm.py index 30033860b..d4fcaa33d 100644 --- a/pilot/model/llm_out/vicuna_base_llm.py +++ b/pilot/model/llm_out/vicuna_base_llm.py @@ -11,6 +11,7 @@ def generate_stream( """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py""" prompt = params["prompt"] l_prompt = len(prompt) + prompt= prompt.replace("ai:", "assistant:").replace("human:", "user:") temperature = float(params.get("temperature", 1.0)) max_new_tokens = int(params.get("max_new_tokens", 2048)) stop_str = params.get("stop", None) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 6eecbbc71..449df3fe4 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -115,7 +115,7 @@ class BaseChat(ABC): payload = { "model": self.llm_model, - "prompt": self.generate_llm_text().replace("ai:", "assistant:"), + "prompt": self.generate_llm_text(), "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), "stop": self.prompt_template.sep, diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index c7237671b..64432520e 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -42,8 +42,9 @@ class DbChatOutputParser(BaseOutputParser): html_table = df.to_html(index=False, escape=False) html = f"{table_style}{html_table}" else: - html = df.to_html(index=False, escape=False, sparsify=False) - html = "".join(html.split()) + html_table = df.to_html(index=False, escape=False, sparsify=False) + table_str = "".join(html_table.split()) + html = f"""
{table_str}
""" view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") return view_text diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 1db3d7791..26d85cd61 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -11,7 +11,7 @@ CFG = Config() PROMPT_SCENE_DEFINE = None _DEFAULT_TEMPLATE = """ -You are a SQL expert. Given an input question, create a syntactically correct {dialect} query. +You are a SQL expert. Given an input question, create a syntactically correct {dialect} sql. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. Use as few tables as possible when querying. @@ -51,7 +51,6 @@ prompt = PromptTemplate( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT ), example_selector=sql_data_example, - # example_selector=None, temperature=PROMPT_TEMPERATURE ) CFG.prompt_templates.update({prompt.template_scene: prompt})