WEB API independent

This commit is contained in:
tuyang.yhj 2023-07-04 16:41:02 +08:00
parent 560773296e
commit 5f9c36a050
9 changed files with 50 additions and 24 deletions

View File

@ -1,10 +1,17 @@
import json
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
CFG = Config() CFG = Config()
if __name__ == "__main__": if __name__ == "__main__":
connect = CFG.local_db.get_session("gpt-user") # connect = CFG.local_db.get_session("gpt-user")
datas = CFG.local_db.run(connect, "SELECT * FROM users; ") # datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
# print(datas)
str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""
print(str.find("["))
print(datas)

View File

@ -14,7 +14,6 @@ def generate_stream(
temperature = float(params.get("temperature", 1.0)) temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048)) max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None) stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids) output_ids = list(input_ids)

View File

@ -113,25 +113,36 @@ class BaseOutputParser(ABC):
ai_response = ai_response.replace("\n", " ") ai_response = ai_response.replace("\n", " ")
ai_response = ai_response.replace("\_", "_") ai_response = ai_response.replace("\_", "_")
ai_response = ai_response.replace("\*", "*") ai_response = ai_response.replace("\*", "*")
ai_response = ai_response.replace("\t", "")
print("un_stream ai response:", ai_response) print("un_stream ai response:", ai_response)
return ai_response return ai_response
else: else:
raise ValueError("Model server error!code=" + resp_obj_ex["error_code"]) raise ValueError("Model server error!code=" + resp_obj_ex["error_code"])
def __illegal_json_ends(self, s):
temp_json = s
illegal_json_ends_1 = [", }", ",}"]
illegal_json_ends_2 = ", ]", ",]"
for illegal_json_end in illegal_json_ends_1:
temp_json = temp_json.replace(illegal_json_end, " }")
for illegal_json_end in illegal_json_ends_2:
temp_json = temp_json.replace(illegal_json_end, " ]")
return temp_json
def __extract_json(self, s): def __extract_json(self, s):
temp_json = self.__json_interception(s, True) temp_json = self.__json_interception(s, True)
if not temp_json: if not temp_json:
temp_json = self.__json_interception(s) temp_json = self.__json_interception(s)
try: try:
json.loads(temp_json) temp_json = self.__illegal_json_ends(temp_json)
return temp_json return temp_json
except Exception as e: except Exception as e:
raise ValueError("Failed to find a valid json response" + temp_json) raise ValueError("Failed to find a valid json response" + temp_json)
def __json_interception(self, s, is_json_array: bool = False): def __json_interception(self, s, is_json_array: bool = False):
if is_json_array: if is_json_array:
i = s.index("[") i = s.find("[")
if i <0: if i <0:
return None return None
count = 1 count = 1
@ -145,7 +156,7 @@ class BaseOutputParser(ABC):
assert count == 0 assert count == 0
return s[i: j + 1] return s[i: j + 1]
else: else:
i = s.index("{") i = s.find("{")
if i <0: if i <0:
return None return None
count = 1 count = 1
@ -189,6 +200,7 @@ class BaseOutputParser(ABC):
.replace("\\n", " ") .replace("\\n", " ")
.replace("\\", " ") .replace("\\", " ")
) )
cleaned_output = self.__illegal_json_ends(cleaned_output)
return cleaned_output return cleaned_output
def parse_view_response(self, ai_text, data) -> str: def parse_view_response(self, ai_text, data) -> str:

View File

@ -51,6 +51,9 @@ class PromptTemplate(BaseModel, ABC):
need_historical_messages: bool = False need_historical_messages: bool = False
temperature: float = 0.6
max_new_tokens: int = 1024
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""

View File

@ -48,8 +48,6 @@ CFG = Config()
class BaseChat(ABC): class BaseChat(ABC):
chat_scene: str = None chat_scene: str = None
llm_model: Any = None llm_model: Any = None
temperature: float = 0.6
max_new_tokens: int = 1024
# By default, keep the last two rounds of conversation records as the context # By default, keep the last two rounds of conversation records as the context
chat_retention_rounds: int = 1 chat_retention_rounds: int = 1
@ -117,9 +115,9 @@ class BaseChat(ABC):
payload = { payload = {
"model": self.llm_model, "model": self.llm_model,
"prompt": self.generate_llm_text(), "prompt": self.generate_llm_text().replace("ai:", "assistant:"),
"temperature": float(self.temperature), "temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.max_new_tokens), "max_new_tokens": int(self.prompt_template.max_new_tokens),
"stop": self.prompt_template.sep, "stop": self.prompt_template.sep,
} }
return payload return payload
@ -128,6 +126,7 @@ class BaseChat(ABC):
# TODO Retry when server connection error # TODO Retry when server connection error
payload = self.__call_base() payload = self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11 self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
ai_response_text = "" ai_response_text = ""

View File

@ -3,7 +3,7 @@
"name": "sale_report", "name": "sale_report",
"introduce": "", "introduce": "",
"layout": "TODO", "layout": "TODO",
"supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart", "Scatterplot", "IndicatorValue", "Table"], "supported_chart_type":["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "IndicatorValue"],
"key_metrics":[], "key_metrics":[],
"trends": [] "trends": []
} }

View File

@ -14,8 +14,8 @@ EXAMPLES = [
\"sql\": \"SELECT city FROM users where user_name='test1'\", \"sql\": \"SELECT city FROM users where user_name='test1'\",
}""", }""",
"example": True, "example": True,
}, }
}, }
] ]
}, },
{ {
@ -29,10 +29,10 @@ EXAMPLES = [
\"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\", \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
}""", }""",
"example": True, "example": True,
}, }
}, }
] ]
}, }
] ]
sql_data_example = ExampleSelector( sql_data_example = ExampleSelector(

View File

@ -10,7 +10,6 @@ CFG = Config()
PROMPT_SCENE_DEFINE = None PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """ _DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, create a syntactically correct {dialect} query. You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
@ -36,6 +35,11 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False PROMPT_NEED_NEED_STREAM_OUT = False
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate( prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(), template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"], input_variables=["input", "table_info", "dialect", "top_k", "response"],
@ -47,5 +51,7 @@ prompt = PromptTemplate(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
), ),
example_selector=sql_data_example, example_selector=sql_data_example,
# example_selector=None,
temperature=PROMPT_TEMPERATURE
) )
CFG.prompt_templates.update({prompt.template_scene: prompt}) CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -14,8 +14,8 @@ EXAMPLES = [
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}}, \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""", }""",
"example": True, "example": True,
}, }
}, }
] ]
}, },
{ {
@ -30,10 +30,10 @@ EXAMPLES = [
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}}, \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""", }""",
"example": True, "example": True,
}, }
}, }
] ]
}, }
] ]
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True) plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)