Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework

aries_ckt 2023-07-04 17:25:30 +08:00
commit 1efaa55515
12 changed files with 59 additions and 30 deletions

View File

@@ -18,6 +18,7 @@ class Config(metaclass=Singleton):
"""Initialize the Config class"""
self.NEW_SERVER_MODE = False
self.SERVER_LIGHT_MODE = False
# Gradio language version: en, zh
self.LANGUAGE = os.getenv("LANGUAGE", "en")
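
The settings shown above live in the Config initializer and follow the same env-driven pattern as the existing LANGUAGE switch. A minimal sketch of that pattern (illustrative stand-in, not the project's actual Config class):

    import os


    class AppConfig:
        """Illustrative stand-in for the Config singleton."""

        def __init__(self):
            self.NEW_SERVER_MODE = False
            self.SERVER_LIGHT_MODE = False
            # Gradio language version: "en" or "zh", overridable via the LANGUAGE env var
            self.LANGUAGE = os.getenv("LANGUAGE", "en")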

View File

@@ -1,10 +1,17 @@
import json
from pilot.common.sql_database import Database
from pilot.configs.config import Config
CFG = Config()
if __name__ == "__main__":
connect = CFG.local_db.get_session("gpt-user")
datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
# connect = CFG.local_db.get_session("gpt-user")
# datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
# print(datas)
str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""
print(str.find("["))
print(datas)
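
The new debug line relies on str.find, which later hunks also adopt in place of str.index in the output parser; the two differ when the marker is absent, as this small illustration shows (plain Python, no project imports):

    sample = '{ "thoughts": "thought text", "sql": "SELECT 1" ,}'

    # find() reports a missing substring with -1 ...
    print(sample.find("["))      # -1

    # ... while index() raises, which would abort parsing.
    try:
        sample.index("[")
    except ValueError:
        print("index() raised ValueError")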

View File

@@ -11,10 +11,10 @@ def generate_stream(
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
prompt = params["prompt"]
l_prompt = len(prompt)
prompt= prompt.replace("ai:", "assistant:").replace("human:", "user:")
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
output_ids = list(input_ids)
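
The added line normalizes the chat role prefixes to the names the model expects before tokenization. A short sketch of that step together with the surrounding parameter handling (the prompt text is invented and the tokenizer call is omitted):

    params = {
        "prompt": "human: count the users\nai: SELECT COUNT(*) FROM users;",
        "temperature": 0.7,
        "max_new_tokens": 512,
    }

    # Rewrite UI role names to the roles the model was trained on.
    prompt = params["prompt"].replace("ai:", "assistant:").replace("human:", "user:")
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_str = params.get("stop", None)
    print(prompt)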

View File

@@ -113,25 +113,36 @@ class BaseOutputParser(ABC):
ai_response = ai_response.replace("\n", " ")
ai_response = ai_response.replace("\_", "_")
ai_response = ai_response.replace("\*", "*")
ai_response = ai_response.replace("\t", "")
print("un_stream ai response:", ai_response)
return ai_response
else:
raise ValueError("Model server error!code=" + resp_obj_ex["error_code"])
def __illegal_json_ends(self, s):
temp_json = s
illegal_json_ends_1 = [", }", ",}"]
illegal_json_ends_2 = ", ]", ",]"
for illegal_json_end in illegal_json_ends_1:
temp_json = temp_json.replace(illegal_json_end, " }")
for illegal_json_end in illegal_json_ends_2:
temp_json = temp_json.replace(illegal_json_end, " ]")
return temp_json
def __extract_json(self, s):
temp_json = self.__json_interception(s, True)
if not temp_json:
temp_json = self.__json_interception(s)
try:
json.loads(temp_json)
temp_json = self.__illegal_json_ends(temp_json)
return temp_json
except Exception as e:
raise ValueError("Failed to find a valid json response" + temp_json)
def __json_interception(self, s, is_json_array: bool = False):
if is_json_array:
i = s.index("[")
i = s.find("[")
if i <0:
return None
count = 1
@@ -145,7 +156,7 @@ class BaseOutputParser(ABC):
assert count == 0
return s[i: j + 1]
else:
i = s.index("{")
i = s.find("{")
if i <0:
return None
count = 1
@@ -189,6 +200,7 @@ class BaseOutputParser(ABC):
.replace("\\n", " ")
.replace("\\", " ")
)
cleaned_output = self.__illegal_json_ends(cleaned_output)
return cleaned_output
def parse_view_response(self, ai_text, data) -> str:
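
Taken together, the new __illegal_json_ends helper and the index-to-find change make JSON extraction tolerant of the trailing commas models often emit before a closing brace or bracket. A standalone sketch of the same repair idea (the function name here is made up; it is not the class method itself):

    import json


    def strip_illegal_json_ends(s: str) -> str:
        """Remove a trailing comma that sits directly before '}' or ']'."""
        for bad in (", }", ",}"):
            s = s.replace(bad, " }")
        for bad in (", ]", ",]"):
            s = s.replace(bad, " ]")
        return s


    raw = 'reply: { "thoughts": "thought text", "sql": "SELECT 1" ,}'
    start = raw.find("{")                     # -1 instead of ValueError when missing
    repaired = strip_illegal_json_ends(raw[start:])
    print(json.loads(repaired)["sql"])        # SELECT 1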

View File

@@ -51,6 +51,9 @@ class PromptTemplate(BaseModel, ABC):
need_historical_messages: bool = False
temperature: float = 0.6
max_new_tokens: int = 1024
class Config:
"""Configuration for this pydantic object."""

View File

@@ -48,8 +48,6 @@ CFG = Config()
class BaseChat(ABC):
chat_scene: str = None
llm_model: Any = None
temperature: float = 0.6
max_new_tokens: int = 1024
# By default, keep the last two rounds of conversation records as the context
chat_retention_rounds: int = 1
@@ -118,8 +116,8 @@ class BaseChat(ABC):
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
"temperature": float(self.temperature),
"max_new_tokens": int(self.max_new_tokens),
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
"stop": self.prompt_template.sep,
}
return payload
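
The hunk above removes the per-chat sampling attributes and reads them from the scene's prompt template instead, so each chat scene carries its own defaults. A sketch of the resulting payload construction (written as a free function; any names not in the diff are assumptions):

    from types import SimpleNamespace


    def build_payload(llm_model, prompt_text, prompt_template):
        """Assemble a generation request from the scene's prompt template."""
        return {
            "model": llm_model,
            "prompt": prompt_text,
            "temperature": float(prompt_template.temperature),
            "max_new_tokens": int(prompt_template.max_new_tokens),
            "stop": prompt_template.sep,
        }


    tpl = SimpleNamespace(temperature=0.5, max_new_tokens=1024, sep="</s>")
    print(build_payload("vicuna-13b", "SELECT ...", tpl))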
@@ -128,6 +126,7 @@ class BaseChat(ABC):
# TODO Retry when server connection error
payload = self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Request: \n{payload}")
ai_response_text = ""

View File

@@ -3,7 +3,7 @@
"name": "sale_report",
"introduce": "",
"layout": "TODO",
"supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart", "Scatterplot", "IndicatorValue", "Table"],
"supported_chart_type":["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "KeyMetrics"],
"key_metrics":[],
"trends": []
}
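
Because the template now advertises a different chart vocabulary, anything consuming the model's output has to check the proposed chart type against supported_chart_type. A hedged sketch of such a check (the helper and the fallback choice are made up, not part of the project):

    import json

    template = json.loads("""{
        "name": "sale_report",
        "supported_chart_type": ["FacetChart", "GaugeChart", "RadarChart", "Sheet",
            "LineChart", "PieChart", "BarChart", "PointChart", "KeyMetrics"]
    }""")


    def pick_chart_type(proposed: str, fallback: str = "Sheet") -> str:
        # Fall back when the LLM proposes a type this template no longer supports.
        return proposed if proposed in template["supported_chart_type"] else fallback


    print(pick_chart_type("HeatMap"))     # Sheet (HeatMap was removed in this change)
    print(pick_chart_type("LineChart"))   # LineChart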

View File

@@ -11,11 +11,11 @@ EXAMPLES = [
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT city FROM users where user_name='test1'\",
\"sql\": \"SELECT city FROM user where user_name='test1'\",
}""",
"example": True,
},
},
}
}
]
},
{
@@ -26,13 +26,13 @@ EXAMPLES = [
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
}""",
"example": True,
},
},
}
}
]
},
}
]
sql_data_example = ExampleSelector(

View File

@@ -35,15 +35,16 @@ class DbChatOutputParser(BaseOutputParser):
if len(data) <= 1:
data.insert(0, ["result"])
df = pd.DataFrame(data[1:], columns=data[0])
if not CFG.NEW_SERVER_MODE:
if not CFG.NEW_SERVER_MODE and not CFG.SERVER_LIGHT_MODE:
table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
else:
html = df.to_html(index=False, escape=False, sparsify=False)
html = "".join(html.split())
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</table></div>"""
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text
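
Both branches come down to pandas DataFrame.to_html; a minimal sketch of the two rendering modes (CSS abbreviated, wrapper class name as in the hunk, sample data invented):

    import pandas as pd

    data = [["city", "user_count"], ["Beijing", 10], ["Chengdu", 7]]
    df = pd.DataFrame(data[1:], columns=data[0])

    # Full server mode: a styled, standalone HTML document.
    table_style = "<style>table{border-collapse:collapse;width:100%}</style>"
    full_html = f"<html><head>{table_style}</head><body>{df.to_html(index=False, escape=False)}</body></html>"

    # Light mode: compact markup left for the web front end to style.
    table_str = "".join(df.to_html(index=False, escape=False, sparsify=False).split())
    light_html = f'<div class="w-full overflow-auto">{table_str}</div>'
    print(light_html)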

View File

@@ -10,9 +10,8 @@ CFG = Config()
PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
You are a SQL expert. Given an input question, create a syntactically correct {dialect} sql.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
Use as few tables as possible when querying.
@@ -36,6 +35,11 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"],
@@ -47,5 +51,6 @@ prompt = PromptTemplate(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
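
The comment block describes temperature qualitatively; numerically it divides the logits before the softmax, so the 0.5 used here sharpens the output distribution relative to 1.0 without changing the ranking. A tiny worked example:

    import math


    def softmax_with_temperature(logits, temperature):
        scaled = [x / temperature for x in logits]
        z = sum(math.exp(x) for x in scaled)
        return [math.exp(x) / z for x in scaled]


    logits = [2.0, 1.0, 0.5]
    print(softmax_with_temperature(logits, 1.0))  # ~[0.63, 0.23, 0.14]
    print(softmax_with_temperature(logits, 0.5))  # ~[0.84, 0.11, 0.04]: same order, sharper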

View File

@@ -14,8 +14,8 @@ EXAMPLES = [
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
},
},
}
}
]
},
{
@@ -30,10 +30,10 @@ EXAMPLES = [
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
},
},
}
}
]
},
}
]
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -103,7 +103,8 @@ if __name__ == "__main__":
from pilot.server.llmserver import worker
worker.start_check()
CFG.NEW_SERVER_MODE = True
else:
CFG.SERVER_LIGHT_MODE = True
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=args.port)
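
The new else branch lets a deployment run the web server without starting the in-process model worker, recording that choice in SERVER_LIGHT_MODE. A self-contained sketch of the same launch pattern (the --light flag, the app, and the flag variables here are illustrative, not the project's actual CLI):

    import argparse

    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()


    @app.get("/health")
    def health():
        return {"status": "ok"}


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--port", type=int, default=5000)
        parser.add_argument("--light", action="store_true",
                            help="skip starting the in-process model worker")
        args = parser.parse_args()

        NEW_SERVER_MODE = not args.light   # full mode: model worker runs in-process
        SERVER_LIGHT_MODE = args.light     # light mode: web server only

        uvicorn.run(app, host="0.0.0.0", port=args.port)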