feat(editor): editor api devlop

editor api devlop part 1
This commit is contained in:
yhjun1026 2023-08-10 15:01:00 +08:00
parent de3bcd1f68
commit a95d292bf9
17 changed files with 373 additions and 157 deletions

View File

@ -250,17 +250,17 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message"""
return f"Error: {e}"
def __write(self, session, write_sql):
def __write(self, write_sql):
print(f"Write[{write_sql}]")
db_cache = self._engine.url.database
result = session.execute(text(write_sql))
session.commit()
result = self.session.execute(text(write_sql))
self.session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem
session.execute(text(f"use `{db_cache}`"))
self.session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self, session, query, fetch: str = "all"):
def __query(self,query, fetch: str = "all"):
"""
only for query
Args:
@ -274,7 +274,7 @@ class RDBMSDatabase(BaseConnect):
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
cursor = self.session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
@ -288,7 +288,7 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
return result
def query_ex(self, session, query, fetch: str = "all"):
def query_ex(self, query, fetch: str = "all"):
"""
only for query
Args:
@ -300,7 +300,7 @@ class RDBMSDatabase(BaseConnect):
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
cursor = self.session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
@ -313,7 +313,7 @@ class RDBMSDatabase(BaseConnect):
result = list(result)
return field_names, result
def run(self, session, command: str, fetch: str = "all") -> List:
def run(self, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)
if not command:
@ -321,17 +321,17 @@ class RDBMSDatabase(BaseConnect):
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self.__query(session, command, fetch)
return self.__query( command, fetch)
else:
self.__write(session, command)
self.__write( command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query(session, select_sql)
return self.__query( select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
cursor = session.execute(text(command))
session.commit()
cursor = self.session.execute(text(command))
self.session.commit()
if cursor.returns_rows:
result = cursor.fetchall()
field_names = tuple(i[0:] for i in cursor.keys())
@ -339,10 +339,10 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
return self.__query( f"SHOW COLUMNS FROM {table_name}")
return result
else:
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
return self.__query( f"SHOW COLUMNS FROM {table_name}")
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.

View File

@ -94,6 +94,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor.commit()
self.connect.commit()
def update(self, messages:List[OnceConversation]) -> None:
cursor = self.connect.cursor()
cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(messages, ensure_ascii=False), self.chat_seesion_id],
)
cursor.commit()
self.connect.commit()
def clear(self) -> None:
cursor = self.connect.cursor()
cursor.execute(
@ -134,6 +144,24 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return []
def conv_info(self, conv_uid: str = None) -> None:
cursor = self.connect.cursor()
cursor.execute(
"SELECT * FROM chat_history where conv_uid=? ",
[conv_uid],
)
# 获取查询结果字段名
fields = [field[0] for field in cursor.description]
for row in cursor.fetchone():
row_dict = {}
for i, field in enumerate(fields):
row_dict[field] = row[i]
return row_dict
return {}
def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor()
cursor.execute(

View File

@ -1,17 +1,12 @@
import os
import json
from fastapi import (
APIRouter,
Request,
Body,
BackgroundTasks,
)
from typing import List
from pilot.configs.config import Config
from pilot.server.knowledge.service import KnowledgeService
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
@ -19,9 +14,6 @@ from pilot.utils import build_logger
from pilot.openapi.api_view_model import (
Result,
ConversationVo,
MessageVo,
ChatSceneVo,
)
from pilot.openapi.editor_view_model import (
ChatDbRounds,
@ -31,9 +23,11 @@ from pilot.openapi.editor_view_model import (
DbTable
)
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData
from pilot.openapi.api_v1.editor.sql_editor import DataNode,ChartRunData,SqlRunData
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.scene.chat_db.auto_execute.out_parser import SqlAction
router = APIRouter()
CFG = Config()
@ -42,47 +36,165 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str= ""):
return Result.succ(None)
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
logger.info("get_editor_tables:{},{},{},{}", db_name, page_index, page_size, search_str)
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
tables = db_conn.get_table_names()
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
for table in tables:
table_node: DataNode = DataNode(title=table, key=table, type="table")
db_node.children.append(table_node)
fields = db_conn.get_fields("transaction_order")
for field in fields:
table_node.children.append(
DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3],
comment=field[-1]))
return Result.succ(db_node)
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
async def get_editor_sql_rounds(con_uid: str):
return Result.succ(None)
logger.info("get_editor_sql_rounds:{}", con_uid)
history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
result: List = []
for once in history_messages:
round_name: str = ""
for element in once["messages"]:
if element["type"] == "human":
round_name = element["data"]["content"]
if once.get("param_value"):
round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"],
round_name=round_name)
result.append(round)
return Result.succ(result)
@router.get("/v1/editor/sql", response_model=Result[SqlAction])
@router.get("/v1/editor/sql", response_model=Result[dict])
async def get_editor_sql(con_uid: str, round: int):
return Result.succ(None)
@router.get("/v1/editor/chart/details", response_model=Result[ChartDetail])
async def get_editor_sql_rounds(con_uid: str):
return Result.succ(None)
@router.get("/v1/editor/chart", response_model=Result[ChartDetail])
async def get_editor_chart(con_uid: str, chart_uid: str):
return Result.succ(None)
logger.info("get_editor_sql:{},{}", con_uid, round)
history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
if int(once["chat_order"]) == round:
for element in once["messages"]:
if element["type"] == "ai":
return Result.succ(json.loads(element["data"]["content"]))
return Result.faild("没有获取到可用的SQL返回结构")
@router.post("/v1/editor/sql/run", response_model=Result[List[dict]])
async def get_editor_chart(db_name: str, sql: str):
return Result.succ(None)
async def editor_sql_run(db_name: str, sql: str):
logger.info("get_editor_sql_run:{},{}", db_name, sql)
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
return Result.succ(conn.run(sql))
@router.post("/v1/editor/chart/run", response_model=Result[ChartData])
async def get_editor_chart(db_name: str, sql: str):
return Result.succ(None)
@router.post("/v1/sql/editor/submit", response_model=Result)
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0]
if edit_round:
for element in edit_round["messages"]:
if element["type"] == "ai":
element["data"]["content"]=""
if element["type"] == "view":
element["data"]["content"]=""
history_mem.update(history_messages)
return Result.succ(None)
return Result.faild("Edit Faild!")
@router.get("/v1/editor/chart/list", response_model=Result[ChartDetail])
async def get_editor_chart_list(con_uid: str):
logger.info("get_editor_sql_rounds:{}", con_uid)
history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
last_round = max(history_messages, key=lambda x: x['chat_order'])
for element in last_round["messages"]:
if element["type"] == "ai":
return Result.succ(json.loads(element["data"]["content"]))
return Result.faild("没有获取到可用的SQL返回结构")
@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
async def get_editor_chart_info(con_uid: str, chart_uid: str):
logger.info(f"get_editor_sql_rounds:{con_uid}")
logger.info("get_editor_sql_rounds:{}", con_uid)
history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
last_round = max(history_messages, key=lambda x: x['chat_order'])
db_name = last_round["param_value"]
if not db_name:
logger.error("this dashboard dialogue version too old, can't support editor!")
return Result.faild("this dashboard dialogue version too old, can't support editor!")
for element in last_round["messages"]:
if element["type"] == "view":
view_data: dict = json.loads(element["data"]["content"]);
charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_uid, charts))[0]
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'],
chart_type=find_chart['chart_type'],
chart_desc=find_chart['chart_desc'],
chart_sql=find_chart['chart_sql'],
db_name=db_name,
chart_name=find_chart['chart_name'],
chart_value=find_chart['values'],
table_value=conn.run(find_chart['chart_sql'])
)
return Result.succ(detail)
return Result.faild("Can't Find Chart Detail Info!")
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def editor_chart_run(db_name: str, sql: str):
logger.info(f"editor_chart_run:{db_name},{sql}")
dashboard_data_loader:DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
field_names,chart_values = dashboard_data_loader.get_chart_values_by_db(db_conn, sql)
sql_run_data:SqlRunData = SqlRunData(result_info="",
run_cost="",
colunms= field_names,
values= db_conn.query_ex(sql)
)
return Result.succ(ChartRunData(sql_data=sql_run_data,chart_values=chart_values))
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
return Result.succ(None)
logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}")
history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
edit_round = list(filter(lambda x: x['chat_order'] == chart_edit_context.conv_round, history_messages))[0]
if edit_round:
for element in edit_round["messages"]:
if element["type"] == "ai":
view_data: dict = json.loads(element["data"]["content"]);
charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
@router.post("/v1/sql/editor/submit", response_model=Result[bool])
async def chart_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
return Result.succ(None)
if element["type"] == "view":
view_data: dict = json.loads(element["data"]["content"]);
charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
history_mem.update(history_messages)
return Result.succ(None)
return Result.faild("Edit Faild!")

View File

@ -0,0 +1,26 @@
from typing import List
from pydantic import BaseModel, Field, root_validator, validator, Extra
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
class DataNode(BaseModel):
title: str
key: str
type: str = ""
default_value: str = None
can_null: str = 'YES'
comment: str = None
children: List = []
class SqlRunData(BaseModel):
result_info: str
run_cost: str
colunms: List[str]
values: List
class ChartRunData(BaseModel):
sql_data: SqlRunData
chart_values: List[ValueItem]

View File

@ -24,10 +24,12 @@ class ChatDbRounds(BaseModel):
class ChartDetail(BaseModel):
chart_uid: str
chart_type: str
chart_desc: str
chart_sql: str
db_name: str
chart_name: str
chart_value: str
chat_round: int # defualt last round
chart_value: Any
table_value: Any
class ChatChartEditContext(BaseModel):

View File

@ -60,10 +60,11 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(
self,
chat_mode,
chat_session_id,
current_user_input,
self,
chat_mode,
chat_session_id,
current_user_input,
select_param: Any = None
):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
@ -87,6 +88,9 @@ class BaseChat(ABC):
)
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
if select_param:
self.current_message.param_type = chat_mode.param_types()[0]
self.current_message.param_value = select_param
self.current_tokens_used: int = 0
class Config:
@ -111,6 +115,24 @@ class BaseChat(ABC):
def do_action(self, prompt_response):
return prompt_response
def get_llm_speak(self, prompt_define_response):
if hasattr(prompt_define_response, "thoughts"):
if isinstance(prompt_define_response.thoughts, dict):
if "speak" in prompt_define_response.thoughts:
speak_to_user = prompt_define_response.thoughts.get("speak")
else:
speak_to_user = str(prompt_define_response.thoughts)
else:
if hasattr(prompt_define_response.thoughts, "speak"):
speak_to_user = prompt_define_response.thoughts.get("speak")
elif hasattr(prompt_define_response.thoughts, "reasoning"):
speak_to_user = prompt_define_response.thoughts.get("reasoning")
else:
speak_to_user = prompt_define_response.thoughts
else:
speak_to_user = prompt_define_response
return speak_to_user
def __call_base(self):
input_values = self.generate_input_values()
### Chat sequence advance
@ -209,26 +231,13 @@ class BaseChat(ABC):
ai_response_text
)
)
### sql run
result = self.do_action(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"):
if isinstance(prompt_define_response.thoughts, dict):
if "speak" in prompt_define_response.thoughts:
speak_to_user = prompt_define_response.thoughts.get("speak")
else:
speak_to_user = str(prompt_define_response.thoughts)
else:
if hasattr(prompt_define_response.thoughts, "speak"):
speak_to_user = prompt_define_response.thoughts.get("speak")
elif hasattr(prompt_define_response.thoughts, "reasoning"):
speak_to_user = prompt_define_response.thoughts.get("reasoning")
else:
speak_to_user = prompt_define_response.thoughts
else:
speak_to_user = prompt_define_response
view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result
)
### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
self.current_message.add_view_message(view_message)
except Exception as e:
print(traceback.format_exc())
@ -297,7 +306,7 @@ class BaseChat(ABC):
system_messages = []
for system_conv in system_convs:
system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
)
system_messages.append(
ModelMessage(role=system_conv.type, content=system_conv.content)
@ -309,7 +318,7 @@ class BaseChat(ABC):
user_messages = []
if user_conv:
user_text = (
user_conv.type + ":" + user_conv.content + self.prompt_template.sep
user_conv.type + ":" + user_conv.content + self.prompt_template.sep
)
user_messages.append(
ModelMessage(role=user_conv.type, content=user_conv.content)
@ -331,10 +340,10 @@ class BaseChat(ABC):
message_type = round_message["type"]
message_content = round_message["data"]["content"]
example_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
example_messages.append(
ModelMessage(role=message_type, content=message_content)
@ -358,10 +367,10 @@ class BaseChat(ABC):
message_type = first_message["type"]
message_content = first_message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
@ -377,10 +386,10 @@ class BaseChat(ABC):
message_type = round_message["type"]
message_content = round_message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
@ -398,10 +407,10 @@ class BaseChat(ABC):
message_type = message["type"]
message_content = message["data"]["content"]
history_text += (
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
message_type
+ ":"
+ message_content
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)

View File

@ -95,6 +95,7 @@ class ModelMessageRoleType:
SYSTEM = "system"
HUMAN = "human"
AI = "ai"
VIEW = "view"
class Generation(BaseModel):
@ -166,7 +167,7 @@ def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
def _parse_model_messages(
messages: List[ModelMessage],
messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str, str]]]:
""" "
Parameters:

View File

@ -1,26 +1,16 @@
import json
import os
import uuid
from typing import Dict, NamedTuple, List
from decimal import Decimal
from typing import List
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
ValueItem,
)
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
CFG = Config()
@ -85,42 +75,10 @@ class ChatDashboard(BaseChat):
def do_action(self, prompt_response):
### TODO 记录整体信息,处理成功的,和未成功的分开记录处理
chart_datas: List[ChartData] = []
dashboard_data_loader = DashboardDataLoader()
for chart_item in prompt_response:
try:
field_names, datas = self.database.query_ex(
self.db_connect, chart_item.sql
)
values: List[ValueItem] = []
data_map = {}
field_map = {}
index = 0
for field_name in field_names:
data_map.update({f"{field_name}": [row[index] for row in datas]})
index += 1
if not data_map[field_name]:
field_map.update({f"{field_name}": False})
else:
field_map.update(
{
f"{field_name}": all(
isinstance(item, (int, float, Decimal))
for item in data_map[field_name]
)
}
)
for field_name in field_names[1:]:
if not field_map[field_name]:
print("more than 2 non-numeric column")
else:
for data in datas:
value_item = ValueItem(
name=data[0],
type=field_name,
value=data[field_names.index(field_name)],
)
values.append(value_item)
field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.db_connect, chart_item.sql)
chart_datas.append(
ChartData(
chart_uid=str(uuid.uuid1()),
@ -135,7 +93,6 @@ class ChatDashboard(BaseChat):
except Exception as e:
# TODO 修复流程
print(str(e))
return ReportData(
conv_uid=self.chat_session_id,
template_name=self.report_name,

View File

@ -0,0 +1,57 @@
from typing import List
from decimal import Decimal
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
CFG = Config()
logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
class DashboardDataLoader:
def get_chart_values_by_conn(self, db_conn, chart_sql: str) :
logger.info(f"get_chart_values_by_conn:{chart_sql}")
try:
field_names, datas = db_conn.query_ex(chart_sql)
values: List[ValueItem] = []
data_map = {}
field_map = {}
index = 0
for field_name in field_names:
data_map.update({f"{field_name}": [row[index] for row in datas]})
index += 1
if not data_map[field_name]:
field_map.update({f"{field_name}": False})
else:
field_map.update(
{
f"{field_name}": all(
isinstance(item, (int, float, Decimal))
for item in data_map[field_name]
)
}
)
for field_name in field_names[1:]:
if not field_map[field_name]:
logger.info("More than 2 non-numeric column:" + field_name)
else:
for data in datas:
value_item = ValueItem(
name=data[0],
type=field_name,
value=data[field_names.index(field_name)],
)
values.append(value_item)
return field_names, values
except Exception as e:
logger.debug("Prepare Chart Data Faild!" + str(e))
raise ValueError("Prepare Chart Data Faild!")
def get_chart_values_by_db(self, db_name: str, chart_sql: str) :
logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
return self.get_chart_values_by_conn(db_conn, chart_sql)

View File

@ -22,11 +22,13 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, db_name, user_input):
chat_mode = ChatScene.ChatWithDbExecute
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=db_name,
)
if not db_name:
raise ValueError(
@ -35,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat):
self.db_name = db_name
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
self.db_connect = self.database.session
self.top_k: int = 5
self.top_k: int = 200
def generate_input_values(self):
try:

View File

@ -0,0 +1,8 @@
import json
class DbDataLoader:
def get_table_view_by_conn(self, db_conn, chart_sql: str):
pass

View File

@ -25,6 +25,7 @@ class ChatWithDbQA(BaseChat):
chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=db_name,
)
self.db_name = db_name
if db_name:

View File

@ -30,6 +30,7 @@ class ChatWithPlugin(BaseChat):
chat_mode=ChatScene.ChatExecution,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=plugin_selector,
)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry

View File

@ -33,6 +33,7 @@ class ChatNewKnowledge(BaseChat):
chat_mode=ChatScene.ChatNewKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=knowledge_name,
)
self.knowledge_name = knowledge_name
vector_store_config = {

View File

@ -24,6 +24,7 @@ class InnerChatDBSummary(BaseChat):
chat_mode=ChatScene.InnerChatDBSummary,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=db_select,
)
self.db_input = db_select

View File

@ -30,6 +30,8 @@ class OnceConversation:
self.messages: List[BaseMessage] = []
self.start_date: str = ""
self.chat_order: int = 0
self.param_type: str = ""
self.param_value: str = ""
self.cost: int = 0
self.tokens: int = 0
@ -114,9 +116,13 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
"cost": once.cost if once.cost else 0,
"tokens": once.tokens if once.tokens else 0,
"messages": messages_to_dict(once.messages),
"param_type": once.param_type,
"param_value": once.param_value
}
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
return [_conversation_to_dic(m) for m in conversations]
@ -128,6 +134,8 @@ def conversation_from_dict(once: dict) -> OnceConversation:
conversation.tokens = once.get("tokens", 0)
conversation.start_date = once.get("start_date", "")
conversation.chat_order = int(once.get("chat_order"))
conversation.param_type = once.get("param_type", "")
conversation.param_value = once.get("param_value", "")
print(once.get("messages"))
conversation.messages = messages_from_dict(once.get("messages", []))
return conversation

View File

@ -51,19 +51,15 @@ def build_logger(logger_name, logger_filename):
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# stdout_logger = logging.getLogger("stdout")
# stdout_logger.setLevel(logging.INFO)
# sl_1 = StreamToLogger(stdout_logger, logging.INFO)
# sys.stdout = sl_1
#
# stderr_logger = logging.getLogger("stderr")
# stderr_logger.setLevel(logging.ERROR)
# sl = StreamToLogger(stderr_logger, logging.ERROR)
# sys.stderr = sl
# Add a file handler for all loggers
if handler is None:
@ -78,6 +74,12 @@ def build_logger(logger_name, logger_filename):
if isinstance(item, logging.Logger):
item.addHandler(handler)
logging.basicConfig(level=logging.INFO)
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
return logger