Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-11 21:22:28 +00:00)

Commit 1efaa55515: Merge remote-tracking branch 'origin/dev_ty_06_end' into llm_framework
@@ -18,6 +18,7 @@ class Config(metaclass=Singleton):
         """Initialize the Config class"""

         self.NEW_SERVER_MODE = False
+        self.SERVER_LIGHT_MODE = False

         # Gradio language version: en, zh
         self.LANGUAGE = os.getenv("LANGUAGE", "en")
@@ -1,10 +1,17 @@
+import json
 from pilot.common.sql_database import Database
 from pilot.configs.config import Config

 CFG = Config()

 if __name__ == "__main__":
-    connect = CFG.local_db.get_session("gpt-user")
-    datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
-
-    print(datas)
+    # connect = CFG.local_db.get_session("gpt-user")
+    # datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
+
+    # print(datas)
+
+    str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""
+
+    print(str.find("["))
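This scratch script probes two behaviors the rest of the commit addresses: the test string deliberately ends in ` ,}` (the trailing-comma case that the new `__illegal_json_ends` below repairs), and `str.find("[")` returns -1 because the payload holds no JSON array, which is exactly the guard the parser relies on. (Naming the variable `str` shadows the builtin; it is kept verbatim from the commit.) A minimal standalone check, with the SQL shortened for illustration:

s = """{ "thoughts": "thought text", "sql": "SELECT 1" ,}"""  # same ",}" ending as the test string
print(s.find("["))  # -1: no array present, so the array branch must bail out gracefully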
@@ -11,10 +11,10 @@ def generate_stream(
     """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
     prompt = params["prompt"]
     l_prompt = len(prompt)
-
+    prompt = prompt.replace("ai:", "assistant:").replace("human:", "user:")
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_str = params.get("stop", None)

     input_ids = tokenizer(prompt).input_ids
     output_ids = list(input_ids)
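The added replace chain maps DB-GPT's conversation role tags onto the names the FastChat-style model server expects. A quick illustration (sample prompt invented here):

prompt = "human: list all users\nai: SELECT * FROM users;"
prompt = prompt.replace("ai:", "assistant:").replace("human:", "user:")
print(prompt)
# user: list all users
# assistant: SELECT * FROM users;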
@@ -113,25 +113,36 @@ class BaseOutputParser(ABC):
             ai_response = ai_response.replace("\n", " ")
             ai_response = ai_response.replace("\_", "_")
             ai_response = ai_response.replace("\*", "*")
+            ai_response = ai_response.replace("\t", "")
             print("un_stream ai response:", ai_response)
             return ai_response
         else:
             raise ValueError("Model server error!code=" + resp_obj_ex["error_code"])

+    def __illegal_json_ends(self, s):
+        temp_json = s
+        illegal_json_ends_1 = [", }", ",}"]
+        illegal_json_ends_2 = ", ]", ",]"
+        for illegal_json_end in illegal_json_ends_1:
+            temp_json = temp_json.replace(illegal_json_end, " }")
+        for illegal_json_end in illegal_json_ends_2:
+            temp_json = temp_json.replace(illegal_json_end, " ]")
+        return temp_json
+
     def __extract_json(self, s):
         temp_json = self.__json_interception(s, True)
         if not temp_json:
             temp_json = self.__json_interception(s)
         try:
-            json.loads(temp_json)
+            temp_json = self.__illegal_json_ends(temp_json)
             return temp_json
         except Exception as e:
             raise ValueError("Failed to find a valid json response!" + temp_json)

     def __json_interception(self, s, is_json_array: bool = False):
         if is_json_array:
-            i = s.index("[")
+            i = s.find("[")
             if i < 0:
                 return None
             count = 1
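Note the behavioral shift in `__extract_json`: the old body validated the candidate with `json.loads`, while the new body only rewrites illegal endings, so a string that is still malformed after repair is now returned rather than rejected. A standalone sketch of the repair itself:

import json

candidate = '{ "thoughts": "x", "sql": "SELECT 1" ,}'
repaired = candidate.replace(", }", " }").replace(",}", " }")  # same substitutions as __illegal_json_ends
print(json.loads(repaired))  # the trailing comma is gone, so this now parses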
@@ -145,7 +156,7 @@ class BaseOutputParser(ABC):
             assert count == 0
             return s[i: j + 1]
         else:
-            i = s.index("{")
+            i = s.find("{")
             if i < 0:
                 return None
             count = 1
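The `index` → `find` switch in both branches is what makes the existing `if i < 0` guard reachable: `str.index` raises `ValueError` when the needle is absent, whereas `str.find` returns -1. For example:

s = "no json here"
print(s.find("{"))   # -1, so the caller can return None
try:
    s.index("{")
except ValueError:
    print("index() raises instead of returning -1")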
@@ -189,6 +200,7 @@ class BaseOutputParser(ABC):
             .replace("\\n", " ")
             .replace("\\", " ")
         )
+        cleaned_output = self.__illegal_json_ends(cleaned_output)
         return cleaned_output

     def parse_view_response(self, ai_text, data) -> str:
@@ -51,6 +51,9 @@ class PromptTemplate(BaseModel, ABC):
     need_historical_messages: bool = False

+    temperature: float = 0.6
+    max_new_tokens: int = 1024
+
     class Config:
         """Configuration for this pydantic object."""

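Since `PromptTemplate` is a pydantic model, the new fields are overridable per scene at construction time; a minimal sketch with a stand-in model (class simplified here, not the real `PromptTemplate`):

from pydantic import BaseModel

class MiniPromptTemplate(BaseModel):  # simplified stand-in for illustration
    need_historical_messages: bool = False
    temperature: float = 0.6
    max_new_tokens: int = 1024

print(MiniPromptTemplate().temperature)                 # 0.6, the default
print(MiniPromptTemplate(temperature=0.5).temperature)  # 0.5, a per-scene override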
@@ -48,8 +48,6 @@ CFG = Config()
 class BaseChat(ABC):
     chat_scene: str = None
     llm_model: Any = None
-    temperature: float = 0.6
-    max_new_tokens: int = 1024
     # By default, keep the last two rounds of conversation records as the context
     chat_retention_rounds: int = 1

@@ -118,8 +116,8 @@ class BaseChat(ABC):
         payload = {
             "model": self.llm_model,
             "prompt": self.generate_llm_text(),
-            "temperature": float(self.temperature),
-            "max_new_tokens": int(self.max_new_tokens),
+            "temperature": float(self.prompt_template.temperature),
+            "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "stop": self.prompt_template.sep,
         }
         return payload
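With the generation parameters now read from `self.prompt_template`, each chat scene controls its own sampling settings instead of inheriting one pair of values from `BaseChat`. A hedged sketch of the resulting payload construction (the stub class and sample values are assumptions for illustration):

def build_payload(llm_model: str, prompt_text: str, tpl) -> dict:
    # tpl is any object exposing temperature, max_new_tokens and sep,
    # as PromptTemplate does after this commit.
    return {
        "model": llm_model,
        "prompt": prompt_text,
        "temperature": float(tpl.temperature),
        "max_new_tokens": int(tpl.max_new_tokens),
        "stop": tpl.sep,
    }

class Tpl:  # throwaway stub for the example
    temperature, max_new_tokens, sep = 0.5, 1024, "###"

print(build_payload("vicuna-13b", "SELECT ...", Tpl))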
@@ -128,6 +126,7 @@ class BaseChat(ABC):
         # TODO Retry when server connection error
         payload = self.__call_base()

+
         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
@@ -3,7 +3,7 @@
     "name": "sale_report",
     "introduce": "",
     "layout": "TODO",
-    "supported_chart_type":["HeatMap","sheet", "LineChart", "PieChart", "BarChart", "Scatterplot", "IndicatorValue", "Table"],
+    "supported_chart_type":["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart", "PieChart", "BarChart", "PointChart", "KeyMetrics"],
     "key_metrics":[],
     "trends": []
 }
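Consumers of this scene definition can guard against the model proposing a chart type that is no longer in the list (for instance the removed "HeatMap"); a small hypothetical check:

supported = ["FacetChart", "GaugeChart", "RadarChart", "Sheet", "LineChart",
             "PieChart", "BarChart", "PointChart", "KeyMetrics"]
proposed = "HeatMap"                                    # an old, now-unsupported name
chart = proposed if proposed in supported else "Sheet"  # hypothetical fallback choice
print(chart)  # Sheet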
@@ -11,11 +11,11 @@ EXAMPLES = [
         "data": {
             "content": """{
                 \"thoughts\": \"thought text\",
-                \"sql\": \"SELECT city FROM users where user_name='test1'\",
+                \"sql\": \"SELECT city FROM user where user_name='test1'\",
             }""",
             "example": True,
-        },
-    },
+        }
+    }
     ]
 },
 {
@@ -26,13 +26,13 @@ EXAMPLES = [
         "data": {
             "content": """{
                 \"thoughts\": \"thought text\",
-                \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
+                \"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
             }""",
             "example": True,
-        },
-    },
+        }
+    }
     ]
-},
+}
 ]

 sql_data_example = ExampleSelector(
@@ -35,15 +35,16 @@ class DbChatOutputParser(BaseOutputParser):
         if len(data) <= 1:
             data.insert(0, ["result"])
         df = pd.DataFrame(data[1:], columns=data[0])
-        if not CFG.NEW_SERVER_MODE:
+        if not CFG.NEW_SERVER_MODE and not CFG.SERVER_LIGHT_MODE:
             table_style = """<style>
                 table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
             </style>"""
             html_table = df.to_html(index=False, escape=False)
             html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
         else:
-            html = df.to_html(index=False, escape=False, sparsify=False)
-            html = "".join(html.split())
+            html_table = df.to_html(index=False, escape=False, sparsify=False)
+            table_str = "".join(html_table.split())
+            html = f"""<div class="w-full overflow-auto">{table_str}</table></div>"""

         view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
         return view_text
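In the new light path the table is emitted without the inline `<style>` block, whitespace-collapsed, and wrapped in a scrollable div (the stray closing `</table>` is in the commit as written). A runnable sketch of that branch with pandas (sample data invented):

import pandas as pd

df = pd.DataFrame([[1, "test1"], [2, "test2"]], columns=["id", "user_name"])
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())  # collapse all whitespace, as the commit does
html = f"""<div class="w-full overflow-auto">{table_str}</table></div>"""
print(html[:60], "...")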
@@ -10,9 +10,8 @@ CFG = Config()

 PROMPT_SCENE_DEFINE = None

-
 _DEFAULT_TEMPLATE = """
-You are a SQL expert. Given an input question, create a syntactically correct {dialect} query.
+You are a SQL expert. Given an input question, create a syntactically correct {dialect} sql.

 Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
 Use as few tables as possible when querying.
@@ -36,6 +35,11 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = False

+# Temperature is a configuration hyperparameter that controls the randomness of language model output.
+# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
+# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
+PROMPT_TEMPERATURE = 0.5
+
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDbExecute.value(),
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
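The new comment block describes sampling temperature; its effect is just a rescaling of the logits before softmax, which a few lines make concrete (illustrative logits, not taken from any model):

import math

def softmax_with_temperature(logits, t):
    scaled = [x / t for x in logits]
    m = max(scaled)                       # subtract max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.5]
print(softmax_with_temperature(logits, 0.5))  # sharper: mass concentrates on the top token
print(softmax_with_temperature(logits, 1.0))  # flatter: more diverse sampling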
@@ -47,5 +51,6 @@ prompt = PromptTemplate(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
     ),
     example_selector=sql_data_example,
+    temperature=PROMPT_TEMPERATURE
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})
@@ -14,8 +14,8 @@ EXAMPLES = [
                 \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
             }""",
             "example": True,
-        },
-    },
+        }
+    }
     ]
 },
 {
@@ -30,10 +30,10 @@ EXAMPLES = [
                 \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
             }""",
             "example": True,
-        },
-    },
+        }
+    }
     ]
-},
+}
 ]

 plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
@@ -103,7 +103,8 @@ if __name__ == "__main__":
         from pilot.server.llmserver import worker

         worker.start_check()
         CFG.NEW_SERVER_MODE = True
-
+    else:
+        CFG.SERVER_LIGHT_MODE = True

     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=args.port)
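The added `else` branch makes the two server modes mutually exclusive: when the webserver does not start the in-process model worker, it now flags itself as light so `DbChatOutputParser` skips the heavyweight styled-HTML path. A hedged sketch of the selection logic (the flag name `light` is an assumption; the actual condition is not shown in this diff):

class Args:       # stand-in for the parsed CLI arguments
    light = True  # hypothetical flag controlling the mode

NEW_SERVER_MODE = False
SERVER_LIGHT_MODE = False

if not Args.light:
    NEW_SERVER_MODE = True     # start the llmserver worker in-process
else:
    SERVER_LIGHT_MODE = True   # rely on an external model server

assert NEW_SERVER_MODE != SERVER_LIGHT_MODE  # exactly one mode is active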