WEB API independent

tuyang.yhj 2023-07-04 16:41:02 +08:00
parent 560773296e
commit 5f9c36a050
9 changed files with 50 additions and 24 deletions

View File

@@ -1,10 +1,17 @@
 import json
 from pilot.common.sql_database import Database
 from pilot.configs.config import Config

 CFG = Config()

 if __name__ == "__main__":
-    connect = CFG.local_db.get_session("gpt-user")
-    datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
-    print(datas)
+    # connect = CFG.local_db.get_session("gpt-user")
+    # datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
+    # print(datas)
+    str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""
+    print(str.find("["))
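A note on this scratch test: the parser changes below switch from `str.index` to `str.find`, and this file exercises exactly that. The difference, in a minimal sketch (sample string assumed):

```python
sample = '{ "thoughts": "thought text", "sql": "SELECT 1" ,}'

# find() returns -1 on a miss, so callers can branch on the result
print(sample.find("["))  # -1: the response contains no JSON array

# index() raises instead, which is why the parser below guards with i < 0
try:
    sample.index("[")
except ValueError:
    print("index() raises ValueError when the needle is absent")
```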

View File

@@ -14,7 +14,6 @@ def generate_stream(
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_str = params.get("stop", None)
     input_ids = tokenizer(prompt).input_ids
     output_ids = list(input_ids)
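The surviving lines read the generation knobs off the request payload with safe defaults; a minimal sketch with a hypothetical payload (this mirrors what `BaseChat.__call_base()` sends, shown further below):

```python
# Hypothetical request payload from the web layer
params = {"prompt": "Hello", "temperature": 0.5, "stop": "</s>"}

temperature = float(params.get("temperature", 1.0))       # 0.5 from the payload
max_new_tokens = int(params.get("max_new_tokens", 2048))  # absent, so default 2048
stop_str = params.get("stop", None)                       # "</s>"
print(temperature, max_new_tokens, stop_str)
```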

View File

@@ -113,25 +113,36 @@ class BaseOutputParser(ABC):
             ai_response = ai_response.replace("\n", " ")
             ai_response = ai_response.replace("\_", "_")
             ai_response = ai_response.replace("\*", "*")
             ai_response = ai_response.replace("\t", "")
             print("un_stream ai response:", ai_response)
             return ai_response
         else:
             raise ValueError("Model server error!code=" + resp_obj_ex["error_code"])
+    def __illegal_json_ends(self, s):
+        temp_json = s
+        illegal_json_ends_1 = [", }", ",}"]
+        illegal_json_ends_2 = [", ]", ",]"]
+        for illegal_json_end in illegal_json_ends_1:
+            temp_json = temp_json.replace(illegal_json_end, " }")
+        for illegal_json_end in illegal_json_ends_2:
+            temp_json = temp_json.replace(illegal_json_end, " ]")
+        return temp_json
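What the new `__illegal_json_ends` helper buys: LLMs frequently emit a trailing comma before a closing brace or bracket, which `json.loads` rejects. A standalone sketch of the same repair (note it is a plain substring heuristic, so a literal `, }` inside a string value would also be rewritten):

```python
import json

broken = '{ "thoughts": "t", "sql": "SELECT 1" ,}'
fixed = broken.replace(", }", " }").replace(",}", " }")
print(json.loads(fixed))  # {'thoughts': 't', 'sql': 'SELECT 1'}
```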
     def __extract_json(self, s):
         temp_json = self.__json_interception(s, True)
         if not temp_json:
             temp_json = self.__json_interception(s)
         try:
-            json.loads(temp_json)
+            temp_json = self.__illegal_json_ends(temp_json)
             return temp_json
         except Exception as e:
             raise ValueError("Failed to find a valid json response" + temp_json)
     def __json_interception(self, s, is_json_array: bool = False):
         if is_json_array:
-            i = s.index("[")
+            i = s.find("[")
             if i < 0:
                 return None
             count = 1
@@ -145,7 +156,7 @@ class BaseOutputParser(ABC):
             assert count == 0
             return s[i : j + 1]
         else:
-            i = s.index("{")
+            i = s.find("{")
             if i < 0:
                 return None
             count = 1
@@ -189,6 +200,7 @@ class BaseOutputParser(ABC):
             .replace("\\n", " ")
             .replace("\\", " ")
         )
+        cleaned_output = self.__illegal_json_ends(cleaned_output)
         return cleaned_output
     def parse_view_response(self, ai_text, data) -> str:
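The surrounding `__json_interception` logic scans from the first opening brace and tracks nesting depth until it returns to zero. A self-contained sketch of that technique (function name hypothetical, not part of the diff):

```python
def extract_json_object(s: str):
    """Return the first balanced {...} block in s, or None."""
    i = s.find("{")          # find() gives -1 on a miss, unlike index()
    if i < 0:
        return None
    count = 0
    for j in range(i, len(s)):
        if s[j] == "{":
            count += 1
        elif s[j] == "}":
            count -= 1
        if count == 0:       # depth back to zero: the object closed at j
            return s[i : j + 1]
    return None              # unbalanced braces

print(extract_json_object('noise {"a": {"b": 1}} trailing'))  # {"a": {"b": 1}}
```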

View File

@@ -51,6 +51,9 @@ class PromptTemplate(BaseModel, ABC):
     need_historical_messages: bool = False
+    temperature: float = 0.6
+    max_new_tokens: int = 1024
     class Config:
         """Configuration for this pydantic object."""

View File

@@ -48,8 +48,6 @@ CFG = Config()
 class BaseChat(ABC):
     chat_scene: str = None
     llm_model: Any = None
-    temperature: float = 0.6
-    max_new_tokens: int = 1024
     # By default, keep the last two rounds of conversation records as the context
     chat_retention_rounds: int = 1
@@ -117,9 +115,9 @@ class BaseChat(ABC):
         payload = {
             "model": self.llm_model,
-            "prompt": self.generate_llm_text(),
-            "temperature": float(self.temperature),
-            "max_new_tokens": int(self.max_new_tokens),
+            "prompt": self.generate_llm_text().replace("ai:", "assistant:"),
+            "temperature": float(self.prompt_template.temperature),
+            "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "stop": self.prompt_template.sep,
         }
         return payload
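The effect of the payload change: the generation knobs now come from the scene's prompt template, and the legacy `ai:` role tag is normalized to `assistant:` before the request goes out. A sketch with hypothetical stand-in values:

```python
prompt_text = "user: show me January\nai: "  # hypothetical rendered prompt

payload = {
    "model": "vicuna-13b",                               # assumed example model name
    "prompt": prompt_text.replace("ai:", "assistant:"),  # normalize the role tag
    "temperature": 0.5,                                  # from prompt_template.temperature
    "max_new_tokens": 1024,                              # from prompt_template.max_new_tokens
    "stop": "</s>",                                      # prompt_template.sep
}
print(payload["prompt"])  # user: show me January / assistant:
```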
@@ -128,6 +126,7 @@ class BaseChat(ABC):
         # TODO Retry when server connection error
         payload = self.__call_base()
         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
+        logger.info(f"Request: \n{payload}")
         ai_response_text = ""
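`skip_echo_len` accounts for model servers that stream the prompt back before the completion; the client slices that echo off before rendering. A rough sketch (the `+ 11` framing offset is copied from the diff; the echoed stream below is hypothetical):

```python
prompt = "assistant: hello"
skip_echo_len = len(prompt.replace("</s>", " ")) + 11

streamed = prompt + " " * 11 + "Hi, how can I help?"  # hypothetical echo plus completion
print(streamed[skip_echo_len:])                       # Hi, how can I help?
```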

View File

@@ -3,7 +3,7 @@
     "name": "sale_report",
     "introduce": "",
     "layout": "TODO",
-    "supported_chart_type": ["HeatMap", "sheet", "LineChart", "PieChart", "BarChart", "Scatterplot", "IndicatorValue", "Table"],
+    "supported_chart_type": ["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "IndicatorValue"],
     "key_metrics": [],
     "trends": []
 }
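Since `HeatMap`, `Scatterplot`, and `Table` drop out of the supported list (and `sheet` is recased to `Sheet`), downstream code presumably needs to guard against the model proposing a retired type. A hypothetical guard, not part of the diff:

```python
SUPPORTED = ["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart",
             "PieChart", "BarChart", "PointChart", "IndicatorValue"]

def pick_chart(requested: str, fallback: str = "Sheet") -> str:
    # Hypothetical helper: fall back when the model proposes an unsupported type
    return requested if requested in SUPPORTED else fallback

print(pick_chart("HeatMap"))   # Sheet, since HeatMap is no longer supported
print(pick_chart("BarChart"))  # BarChart
```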

View File

@@ -14,8 +14,8 @@ EXAMPLES = [
                 \"sql\": \"SELECT city FROM users where user_name='test1'\",
             }""",
             "example": True,
-            },
-        },
+            }
+        }
         ]
     },
 {
@@ -29,10 +29,10 @@ EXAMPLES = [
                 \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
             }""",
             "example": True,
-            },
-        },
+            }
+        }
         ]
     },
 }
 ]
 sql_data_example = ExampleSelector(
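Note how this ties back to the parser changes: the few-shot content above still ends its JSON with `",` before the closing brace, so the model will imitate that illegal ending. The parser first collapses newlines, then `__illegal_json_ends` repairs the result. The chain, in miniature:

```python
import json

raw = '{\n "thoughts": "thought text",\n "sql": "SELECT city FROM users",\n}'
flat = raw.replace("\n", " ")                          # parse step 1: strip newlines
fixed = flat.replace(", }", " }").replace(",}", " }")  # step 2: repair the illegal ending
print(json.loads(fixed)["sql"])                        # SELECT city FROM users
```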

View File

@@ -10,7 +10,6 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = None
 _DEFAULT_TEMPLATE = """
 You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
@@ -36,6 +35,11 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
 PROMPT_NEED_NEED_STREAM_OUT = False
+# Temperature is a configuration hyperparameter that controls the randomness of language model output.
+# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
+# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
+PROMPT_TEMPERATURE = 0.5
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDbExecute.value(),
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
@@ -47,5 +51,7 @@ prompt = PromptTemplate(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
     ),
     example_selector=sql_data_example,
+    # example_selector=None,
+    temperature=PROMPT_TEMPERATURE
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})
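Each scene file ends by registering its template; chat construction then looks the template up by scene name and, with this commit, pulls `temperature` from it. A self-contained sketch with a stand-in template class (the scene key is assumed for illustration):

```python
from dataclasses import dataclass

@dataclass
class StubTemplate:                 # stand-in for PromptTemplate
    template_scene: str
    temperature: float = 0.5        # PROMPT_TEMPERATURE: conservative, suited to SQL generation

prompt_templates = {}
prompt = StubTemplate(template_scene="chat_with_db_execute")  # assumed scene value
prompt_templates.update({prompt.template_scene: prompt})

print(prompt_templates["chat_with_db_execute"].temperature)   # 0.5
```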

View File

@@ -14,8 +14,8 @@ EXAMPLES = [
             \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
         }""",
             "example": True,
-            },
-        },
+            }
+        }
         ]
     },
 {
@@ -30,10 +30,10 @@ EXAMPLES = [
             \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
         }""",
             "example": True,
-            },
-        },
+            }
+        }
         ]
     },
 }
 ]
 plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
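For the plugin scene, the few-shot examples teach a `command` envelope instead of `sql`. A quick shape check of the JSON the examples model (the helper logic is hypothetical, not part of the diff):

```python
import json

response = '{ "thoughts": "thought text", "command": {"name": "command name", "args": {"arg name": "value"}} }'
command = json.loads(response)["command"]
assert set(command) == {"name", "args"}  # the envelope the examples teach
print(command["name"], command["args"])  # command name {'arg name': 'value'}
```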