diff --git a/pilot/connections/rdbms/py_study/study_data.py b/pilot/connections/rdbms/py_study/study_data.py
index c83c52acc..6d6b88e04 100644
--- a/pilot/connections/rdbms/py_study/study_data.py
+++ b/pilot/connections/rdbms/py_study/study_data.py
@@ -1,10 +1,17 @@
+import json
 from pilot.common.sql_database import Database
 from pilot.configs.config import Config
 
 CFG = Config()
 
 if __name__ == "__main__":
-    connect = CFG.local_db.get_session("gpt-user")
-    datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
+    # connect = CFG.local_db.get_session("gpt-user")
+    # datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
+
+    # print(datas)
+
+    str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""
+
+    print(str.find("["))
+
-    print(datas)
diff --git a/pilot/model/llm_out/vicuna_base_llm.py b/pilot/model/llm_out/vicuna_base_llm.py
index 042f9954b..30033860b 100644
--- a/pilot/model/llm_out/vicuna_base_llm.py
+++ b/pilot/model/llm_out/vicuna_base_llm.py
@@ -14,7 +14,6 @@ def generate_stream(
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_str = params.get("stop", None)
-
     input_ids = tokenizer(prompt).input_ids
     output_ids = list(input_ids)
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 3631d12ee..cec51a9c5 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -113,25 +113,36 @@ class BaseOutputParser(ABC):
             ai_response = ai_response.replace("\n", " ")
             ai_response = ai_response.replace("\_", "_")
             ai_response = ai_response.replace("\*", "*")
+            ai_response = ai_response.replace("\t", "")
             print("un_stream ai response:", ai_response)
             return ai_response
         else:
             raise ValueError("Model server error!code=" + resp_obj_ex["error_code"])
 
+    def __illegal_json_ends(self, s):
+        temp_json = s
+        illegal_json_ends_1 = [", }", ",}"]
+        illegal_json_ends_2 = ", ]", ",]"
+        for illegal_json_end in illegal_json_ends_1:
+            temp_json = temp_json.replace(illegal_json_end, " }")
+        for illegal_json_end in illegal_json_ends_2:
+            temp_json = temp_json.replace(illegal_json_end, " ]")
+        return temp_json
+
     def __extract_json(self, s):
         temp_json = self.__json_interception(s, True)
         if not temp_json:
             temp_json = self.__json_interception(s)
         try:
-            json.loads(temp_json)
+            temp_json = self.__illegal_json_ends(temp_json)
             return temp_json
         except Exception as e:
             raise ValueError("Failed to find a valid json response!" + temp_json)
 
     def __json_interception(self, s, is_json_array: bool = False):
         if is_json_array:
-            i = s.index("[")
+            i = s.find("[")
             if i < 0:
                 return None
             count = 1
@@ -145,7 +156,7 @@ class BaseOutputParser(ABC):
             assert count == 0
             return s[i : j + 1]
         else:
-            i = s.index("{")
+            i = s.find("{")
             if i < 0:
                 return None
             count = 1
@@ -189,6 +200,7 @@ class BaseOutputParser(ABC):
             .replace("\\n", " ")
             .replace("\\", " ")
         )
+        cleaned_output = self.__illegal_json_ends(cleaned_output)
         return cleaned_output
 
     def parse_view_response(self, ai_text, data) -> str:
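Reviewer note on the two parser changes above, which interact: `str.index` raises `ValueError` when the character is absent, so the existing `if i < 0` guard could never fire, whereas `str.find` returns -1 and makes the guard effective (the scratch script in `study_data.py` is probing exactly this). The snippet below is a standalone sketch, not the patched class; `repair_illegal_json_ends` mirrors the new private helper, and the sample string is modeled on the one in `study_data.py`:

```python
import json

def repair_illegal_json_ends(s: str) -> str:
    # Mirrors the new __illegal_json_ends helper in pilot/out_parser/base.py:
    # strip the trailing comma LLMs often emit before a closing brace/bracket.
    # Note this is a blunt textual pass; it would also rewrite ",}" occurring
    # inside string values, a trade-off the parser accepts.
    for bad in (", }", ",}"):
        s = s.replace(bad, " }")
    for bad in (", ]", ",]"):
        s = s.replace(bad, " ]")
    return s

# Model output shaped like the scratch string in study_data.py above.
raw = '{ "thoughts": "thought text", "sql": "SELECT 1" ,}'

# str.find returns -1 when the character is absent, so the `if i < 0`
# branch in __json_interception can actually run...
assert raw.find("[") == -1
# ...whereas str.index raises, which made the old guard unreachable.
try:
    raw.index("[")
except ValueError:
    pass

print(json.loads(repair_illegal_json_ends(raw)))
# -> {'thoughts': 'thought text', 'sql': 'SELECT 1'}
```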
diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py
index 80f05c730..78e1585ea 100644
--- a/pilot/prompts/prompt_new.py
+++ b/pilot/prompts/prompt_new.py
@@ -51,6 +51,9 @@ class PromptTemplate(BaseModel, ABC):
 
     need_historical_messages: bool = False
 
+    temperature: float = 0.6
+    max_new_tokens: int = 1024
+
     class Config:
         """Configuration for this pydantic object."""
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 33553cb30..6eecbbc71 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -48,8 +48,6 @@ CFG = Config()
 class BaseChat(ABC):
     chat_scene: str = None
     llm_model: Any = None
-    temperature: float = 0.6
-    max_new_tokens: int = 1024
     # By default, keep the last two rounds of conversation records as the context
     chat_retention_rounds: int = 1
 
@@ -117,9 +115,9 @@ class BaseChat(ABC):
 
         payload = {
            "model": self.llm_model,
-            "prompt": self.generate_llm_text(),
-            "temperature": float(self.temperature),
-            "max_new_tokens": int(self.max_new_tokens),
+            "prompt": self.generate_llm_text().replace("ai:", "assistant:"),
+            "temperature": float(self.prompt_template.temperature),
+            "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "stop": self.prompt_template.sep,
         }
         return payload
@@ -128,6 +126,7 @@ class BaseChat(ABC):
         # TODO Retry when server connection error
         payload = self.__call_base()
 
+        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
 
         ai_response_text = ""
diff --git a/pilot/scene/chat_dashboard/template/sales_report/dashboard.json b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json
index 5a7f836e9..7301c5669 100644
--- a/pilot/scene/chat_dashboard/template/sales_report/dashboard.json
+++ b/pilot/scene/chat_dashboard/template/sales_report/dashboard.json
@@ -3,7 +3,7 @@
   "name": "sale_report",
   "introduce": "",
   "layout": "TODO",
-  "supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart", "Scatterplot", "IndicatorValue", "Table"],
+  "supported_chart_type":["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "IndicatorValue"],
   "key_metrics":[],
   "trends": []
 }
\ No newline at end of file
diff --git a/pilot/scene/chat_db/auto_execute/example.py b/pilot/scene/chat_db/auto_execute/example.py
index 73fea6f51..9f877fdad 100644
--- a/pilot/scene/chat_db/auto_execute/example.py
+++ b/pilot/scene/chat_db/auto_execute/example.py
@@ -14,8 +14,8 @@ EXAMPLES = [
 \"sql\": \"SELECT city FROM users where user_name='test1'\",
 }""",
                     "example": True,
-                },
-            },
+                }
+            }
         ]
     },
     {
@@ -29,10 +29,10 @@ EXAMPLES = [
 \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
 }""",
                     "example": True,
-                },
-            },
+                }
+            }
         ]
-    },
+    }
 ]
 
 sql_data_example = ExampleSelector(
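Reviewer note on the `temperature` / `max_new_tokens` move: the fields leave `BaseChat` and become per-scene knobs on `PromptTemplate`, which is what lets the ChatWithDbExecute prompt in the next diff pin `temperature=0.5`. A minimal sketch of the resulting ownership, with simplified stand-in classes (the field names and payload keys come from this diff; everything else here is illustrative):

```python
from pydantic import BaseModel

class PromptTemplate(BaseModel):
    # Per-scene generation settings, with the defaults this diff adds.
    template_scene: str
    temperature: float = 0.6
    max_new_tokens: int = 1024

class BaseChat:
    def __init__(self, prompt_template: PromptTemplate, llm_model: str):
        self.prompt_template = prompt_template
        self.llm_model = llm_model

    def _call_base(self) -> dict:
        # Generation settings now come from the prompt template, so each
        # chat scene can tune them without touching BaseChat itself.
        return {
            "model": self.llm_model,
            "temperature": float(self.prompt_template.temperature),
            "max_new_tokens": int(self.prompt_template.max_new_tokens),
        }

# A scene that wants conservative SQL generation pins its own value:
sql_prompt = PromptTemplate(template_scene="chat_with_db_execute", temperature=0.5)
print(BaseChat(sql_prompt, "vicuna-13b")._call_base())
```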
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index 5b7174638..1db3d7791 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -10,7 +10,6 @@ CFG = Config()
 
 PROMPT_SCENE_DEFINE = None
 
-
 _DEFAULT_TEMPLATE = """
 You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
 
@@ -36,6 +35,11 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
 
 PROMPT_NEED_NEED_STREAM_OUT = False
 
+# Temperature is a configuration hyperparameter that controls the randomness of language model output.
+# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
+# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
+PROMPT_TEMPERATURE = 0.5
+
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDbExecute.value(),
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
@@ -47,5 +51,7 @@ prompt = PromptTemplate(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
     ),
     example_selector=sql_data_example,
+    # example_selector=None,
+    temperature=PROMPT_TEMPERATURE
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})
diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py
index 35f0c634a..da33e5378 100644
--- a/pilot/scene/chat_execution/example.py
+++ b/pilot/scene/chat_execution/example.py
@@ -14,8 +14,8 @@ EXAMPLES = [
 \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
 }""",
                     "example": True,
-                },
-            },
+                }
+            }
         ]
     },
     {
@@ -30,10 +30,10 @@ EXAMPLES = [
 \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
 }""",
                     "example": True,
-                },
-            },
+                }
+            }
         ]
-    },
+    }
 ]
 
 plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
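Finally, a note on the new `skip_echo_len` assignment in `base_chat.py`: Vicuna-style model workers stream back the prompt followed by the completion, and this character offset lets the caller drop the echoed prompt from the stream. A hedged sketch of how a consumer might use it, assuming that intent; the `"</s>"` spacing and the `+ 11` margin are taken from the diff, while `strip_echo` and the sample strings are hypothetical:

```python
def strip_echo(streamed_text: str, prompt: str) -> str:
    # From the diff: prompt length after spacing out the stop token,
    # plus a fixed 11-character margin for role separators.
    skip_echo_len = len(prompt.replace("</s>", " ")) + 11
    return streamed_text[skip_echo_len:]

# Illustrative worker output: the echoed prompt, separator padding,
# then the actual completion.
prompt = "USER: how many users signed up in January? ASSISTANT:"
streamed = prompt + " " * 11 + "There were 42 sign-ups."
print(strip_echo(streamed, prompt))  # -> "There were 42 sign-ups."
```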