feat(ChatDB): ChatDB Use fintune model

1.Compatible with community pure sql output model
This commit is contained in:
yhjun1026 2023-11-13 16:28:30 +08:00
parent 4c2a033eec
commit f735b12c99
23 changed files with 45 additions and 203 deletions

View File

@ -341,6 +341,7 @@ class ApiCall:
if api_status.api_result: if api_status.api_result:
param["result"] = api_status.api_result param["result"] = api_status.api_result
return json.dumps(param, default=serialize, ensure_ascii=False) return json.dumps(param, default=serialize, ensure_ascii=False)
def to_view_text(self, api_status: PluginStatus): def to_view_text(self, api_status: PluginStatus):
@ -358,10 +359,13 @@ class ApiCall:
return html return html
else: else:
api_call_element = ET.Element("chart-view") api_call_element = ET.Element("chart-view")
api_call_element.text = self.__to_antv_vis_param(api_status) api_call_element.set("content", self.__to_antv_vis_param(api_status))
# api_call_element.text = self.__to_antv_vis_param(api_status)
result = ET.tostring(api_call_element, encoding="utf-8") result = ET.tostring(api_call_element, encoding="utf-8")
return result.decode("utf-8") return result.decode("utf-8")
# return f'<chart-view content="{self.__to_antv_vis_param(api_status)}">'
def __to_antv_vis_param(self, api_status: PluginStatus): def __to_antv_vis_param(self, api_status: PluginStatus):
param = {} param = {}
if api_status.name: if api_status.name:
@ -373,8 +377,9 @@ class ApiCall:
if api_status.api_result: if api_status.api_result:
param["data"] = api_status.api_result param["data"] = api_status.api_result
else:
return json.dumps(param, default=serialize, ensure_ascii=False, separators=(',', ':')) param["data"] =[]
return json.dumps(param, ensure_ascii=False)
def run(self, llm_text): def run(self, llm_text):
if self.__is_need_wait_plugin_call(llm_text): if self.__is_need_wait_plugin_call(llm_text):

View File

@ -13,7 +13,7 @@ class RemoteModelWorker(ModelWorker):
def __init__(self) -> None: def __init__(self) -> None:
self.headers = {} self.headers = {}
# TODO Configured by ModelParameters # TODO Configured by ModelParameters
self.timeout = 360 self.timeout = 3600
self.host = None self.host = None
self.port = None self.port = None

View File

@ -1,6 +1,7 @@
import json import json
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
import logging import logging
import sqlparse
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from pilot.common.json_utils import serialize from pilot.common.json_utils import serialize
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
@ -21,13 +22,27 @@ class SqlAction(NamedTuple):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DbChatOutputParser(BaseOutputParser): class DbChatOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool): def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out) super().__init__(sep=sep, is_stream_out=is_stream_out)
def is_sql_statement(statement):
parsed = sqlparse.parse(statement)
if not parsed:
return False
for stmt in parsed:
if stmt.get_type() != 'UNKNOWN':
return True
return False
def parse_prompt_response(self, model_out_text): def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text) clean_str = super().parse_prompt_response(model_out_text)
logging.info("clean prompt response:", clean_str) logging.info("clean prompt response:", clean_str)
#Compatible with community pure sql output model
if self.is_sql_statement(clean_str):
return SqlAction(clean_str, "")
else:
response = json.loads(clean_str) response = json.loads(clean_str)
for key in sorted(response): for key in sorted(response):
if key.strip() == "sql": if key.strip() == "sql":
@ -57,7 +72,8 @@ class DbChatOutputParser(BaseOutputParser):
err_msg = str(e) err_msg = str(e)
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False) view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
api_call_element.text = view_json_str # api_call_element.text = view_json_str
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8") result = ET.tostring(api_call_element, encoding="utf-8")
if err_msg: if err_msg:
return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result.decode("utf-8")}""" return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result.decode("utf-8")}"""

View File

@ -37,7 +37,8 @@ class DbDataLoader:
err_msg = str(e) err_msg = str(e)
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False) view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
api_call_element.text = view_json_str # api_call_element.text = view_json_str
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8") result = ET.tostring(api_call_element, encoding="utf-8")
if err_msg: if err_msg:
return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result.decode("utf-8")}""" return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result.decode("utf-8")}"""

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1 +0,0 @@
self.__BUILD_MANIFEST=function(s,c,e,a,t,n,b,d,k,h,i,u){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":[b,s,c,e,n,d,"static/chunks/539-dcd22f1f6b99ebee.js","static/chunks/pages/index-2d0eeb81ae5a56b6.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/agent":[s,c,a,k,t,"static/chunks/pages/agent-ff3852522fea07c6.js"],"/chat":["static/chunks/pages/chat-be738dbebc61501d.js"],"/chat/[scene]/[id]":["static/chunks/pages/chat/[scene]/[id]-868f396254bd78ec.js"],"/database":[s,c,e,a,t,n,h,"static/chunks/643-d8f53f40dd3c5b40.js","static/chunks/pages/database-51b362b90c1b0be9.js"],"/knowledge":[i,s,c,a,k,t,n,"static/chunks/63-d9f1013be8e4599a.js","static/chunks/pages/knowledge-7a8d4a6321e572be.js"],"/knowledge/chunk":[a,t,"static/chunks/pages/knowledge/chunk-e27c2e349b868b28.js"],"/models":[i,s,c,e,u,h,"static/chunks/pages/models-97ac98a0a4bed459.js"],"/prompt":[b,s,c,e,u,"static/chunks/837-e6d4d1eb9e057050.js",d,"static/chunks/607-b224c640f6907e4b.js","static/chunks/pages/prompt-b28755caf89f9c30.js"],sortedPages:["/","/_app","/_error","/agent","/chat","/chat/[scene]/[id]","/database","/knowledge","/knowledge/chunk","/models","/prompt"]}}("static/chunks/64-91b49d45b9846775.js","static/chunks/479-b20198841f9a6a1e.js","static/chunks/9-bb2c54d5c06ba4bf.js","static/chunks/442-197e6cbc1e54109a.js","static/chunks/813-cce9482e33f2430c.js","static/chunks/411-d9eba2657c72f766.js","static/chunks/29107295-90b90cb30c825230.js","static/chunks/719-5a18c3c696beda6f.js","static/chunks/365-2cad3676ccbb1b1a.js","static/chunks/928-74244889bd7f2699.js","static/chunks/75fc9c18-a784766a129ec5fb.js","static/chunks/947-5980a3ff49069ddd.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();

View File

@ -1 +0,0 @@
self.__SSG_MANIFEST=new Set([]);self.__SSG_MANIFEST_CB&&self.__SSG_MANIFEST_CB()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File