mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 01:27:14 +00:00
多场景对架构一期0525
This commit is contained in:
@@ -2,10 +2,15 @@ import markdown2
|
||||
import pandas as pd
|
||||
|
||||
def datas_to_table_html(data):
|
||||
df = pd.DataFrame(data)
|
||||
table_style = """\n<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid blue;\n padding: 5px;\n text-align: left;\n }\n th {\n background-color: #f2f2f2;\n }\n</style>\n"""
|
||||
html_table = df.to_html(index=False, header=False, border = True)
|
||||
return table_style + html_table
|
||||
df = pd.DataFrame(data[1:], columns=data[0])
|
||||
table_style = """<style>
|
||||
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
|
||||
</style>"""
|
||||
html_table = df.to_html(index=False, escape=False)
|
||||
|
||||
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
|
||||
|
||||
return html.replace("\n", " ")
|
||||
|
||||
|
||||
|
||||
@@ -43,6 +48,9 @@ def generate_htm_table(data):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
|
||||
# mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
|
||||
# print(generate_htm_table(mk_text))
|
||||
|
||||
print(generate_htm_table(mk_text))
|
||||
table_style = """<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid #ddd;\n padding: 8px;\n text-align: center;\n line-height: 150px; \n }\n th {\n background-color: #f2f2f2;\n color: #333;\n font-weight: bold;\n }\n tr:nth-child(even) {\n background-color: #f9f9f9;\n }\n tr:hover {\n background-color: #f2f2f2;\n }\n </style>"""
|
||||
|
||||
print(table_style.replace("\n", " "))
|
@@ -21,13 +21,23 @@ from pilot.prompts.base import PromptValue
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseOutputParser(BaseModel, ABC, Generic[T]):
|
||||
class BaseOutputParser(ABC):
|
||||
"""Class to parse the output of an LLM call.
|
||||
|
||||
Output parsers help structure language model responses.
|
||||
"""
|
||||
|
||||
def parse_model_nostream_resp(self, response, sep: str):
|
||||
|
||||
def __init__(self,sep:str, is_stream_out:bool):
|
||||
self.sep = sep
|
||||
self.is_stream_out = is_stream_out
|
||||
|
||||
|
||||
# TODO 后续和模型绑定
|
||||
def _parse_model_stream_resp(self, response, sep: str):
|
||||
pass
|
||||
|
||||
def _parse_model_nostream_resp(self, response, sep: str):
|
||||
text = response.text.strip()
|
||||
text = text.rstrip()
|
||||
respObj = json.loads(text)
|
||||
@@ -52,35 +62,44 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
|
||||
else:
|
||||
raise ValueError("Model server error!code=" + respObj_ex['error_code']);
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> T:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
A method which takes in a string (assumed output of language model )
|
||||
and parses it into some structure.
|
||||
|
||||
def parse_model_server_out(self, response)->str:
|
||||
"""
|
||||
parse the model server http response
|
||||
Args:
|
||||
text: output of language model
|
||||
response:
|
||||
|
||||
Returns:
|
||||
structured output
|
||||
|
||||
"""
|
||||
if self.is_stream_out:
|
||||
self._parse_model_nostream_resp(response, self.sep)
|
||||
else:
|
||||
### TODO
|
||||
self._parse_model_stream_resp(response, self.sep)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
"""Optional method to parse the output of an LLM call with a prompt.
|
||||
|
||||
The prompt is largely provided in the event the OutputParser wants
|
||||
to retry or fix the output in some way, and needs information from
|
||||
the prompt to do so.
|
||||
|
||||
def parse_prompt_response(self, model_out_text)->T:
|
||||
"""
|
||||
parse model out text to prompt define response
|
||||
Args:
|
||||
completion: output of language model
|
||||
prompt: prompt value
|
||||
model_out_text:
|
||||
|
||||
Returns:
|
||||
structured output
|
||||
|
||||
"""
|
||||
return self.parse(completion)
|
||||
pass
|
||||
|
||||
|
||||
def parse_view_response(self, ai_text)->str:
|
||||
"""
|
||||
parse the ai response info to user view
|
||||
Args:
|
||||
text:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Instructions on how the LLM output should be formatted."""
|
||||
|
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from pilot.common.formatting import formatter
|
||||
from pilot.out_parser.base import BaseOutputParser
|
||||
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
"""Format a template using jinja2."""
|
||||
@@ -36,7 +36,17 @@ class PromptTemplate(BaseModel, ABC):
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
response_format:str
|
||||
"""default use stream out"""
|
||||
stream_out: bool = True
|
||||
""""""
|
||||
output_parser: BaseOutputParser = None
|
||||
""""""
|
||||
sep:str = SeparatorStyle.SINGLE.value
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
|
@@ -32,16 +32,14 @@ from pilot.configs.config import Config
|
||||
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
CFG = Config()
|
||||
class BaseChat(ABC):
|
||||
class BaseChat( ABC):
|
||||
chat_scene: str = None
|
||||
memory: BaseChatHistoryMemory
|
||||
llm_model: Any = None
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
temperature: float = 0.6
|
||||
max_new_tokens: int = 1024
|
||||
# By default, keep the last two rounds of conversation records as the context
|
||||
chat_retention_rounds: int = 2
|
||||
sep = SeparatorStyle.SINGLE.value
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
arbitrary_types_allowed = True
|
||||
|
@@ -61,6 +61,17 @@ class AIMessage(BaseMessage):
|
||||
return "ai"
|
||||
|
||||
|
||||
class ViewMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the AI."""
|
||||
|
||||
example: bool = False
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "view"
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""Type of message that is a system message."""
|
||||
|
||||
@@ -132,6 +143,8 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
||||
return AIMessage(**message["data"])
|
||||
elif _type == "system":
|
||||
return SystemMessage(**message["data"])
|
||||
elif _type == "view":
|
||||
return ViewMessage(**message["data"])
|
||||
else:
|
||||
raise ValueError(f"Got unexpected type: {_type}")
|
||||
|
||||
|
@@ -13,7 +13,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage
|
||||
from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage, ViewMessage
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
@@ -26,13 +26,15 @@ from pilot.utils import (
|
||||
)
|
||||
from pilot.common.markdown_text import generate_markdown_table,generate_htm_table,datas_to_table_html
|
||||
from pilot.scene.chat_db.prompt import chat_db_prompt
|
||||
|
||||
from pilot.out_parser.base import BaseOutputParser
|
||||
from pilot.scene.chat_db.out_parser import DbChatOutputParser
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatWithDb(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatWithDb.value
|
||||
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, db_name, user_input):
|
||||
@@ -77,17 +79,22 @@ class ChatWithDb(BaseChat):
|
||||
"prompt": self.generate_llm_text(),
|
||||
"temperature": float(self.temperature),
|
||||
"max_new_tokens": int(self.max_new_tokens),
|
||||
"stop": self.sep_style.value,
|
||||
"stop": self.prompt_template.sep,
|
||||
}
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
|
||||
try:
|
||||
### 走非流式的模型服务接口
|
||||
|
||||
# TODO - TEST
|
||||
# # TODO - TEST
|
||||
# response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120)
|
||||
# clear_response = self.prompt_template.output_parser.parse_model_nostream_resp(response, self.sep_style)
|
||||
# sql_action = self.prompt_template.output_parser.parse(clear_response)
|
||||
# ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
|
||||
#
|
||||
# prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
|
||||
# self.current_message.add_ai_message(json.dumps(prompt_define_response._asdict()))
|
||||
# result = self.database.run(self.db_connect, prompt_define_response.SQL)
|
||||
|
||||
|
||||
resp_test = {
|
||||
"SQL": "select * from users",
|
||||
"thoughts": {
|
||||
@@ -100,12 +107,10 @@ class ChatWithDb(BaseChat):
|
||||
}
|
||||
|
||||
sql_action = SqlAction(**resp_test)
|
||||
|
||||
# self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
|
||||
|
||||
self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
|
||||
result = self.database.run(self.db_connect, sql_action.SQL)
|
||||
|
||||
self.current_message.add_ai_message(f"{datas_to_table_html(result)}")
|
||||
self.current_message.add_view_message(self.prompt_template.output_parser.parse_view_response(result))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("model response parase faild!" + str(e))
|
||||
@@ -118,17 +123,19 @@ class ChatWithDb(BaseChat):
|
||||
ret = []
|
||||
# 单论对话只能有一次User 记录 和一次 AI 记录
|
||||
# TODO 推理过程前端展示。。。
|
||||
for message in enumerate(self.current_message.messages):
|
||||
for message in self.current_message.messages:
|
||||
if (isinstance(message, HumanMessage)):
|
||||
ret[-1][-2] = message.content
|
||||
if (isinstance(message, AIMessage)):
|
||||
# 是否展示推理过程
|
||||
if (isinstance(message, ViewMessage)):
|
||||
ret[-1][-1] = message.content
|
||||
|
||||
return ret
|
||||
|
||||
# 暂时为了兼容前端
|
||||
def current_ai_response(self)->str:
|
||||
for message in self.current_message.messages:
|
||||
if message.type == 'ai':
|
||||
if message.type == 'view':
|
||||
return message.content
|
||||
return None
|
||||
|
||||
@@ -137,28 +144,31 @@ class ChatWithDb(BaseChat):
|
||||
text = ""
|
||||
### 线处理历史信息
|
||||
if (len(self.history_message) > self.chat_retention_rounds):
|
||||
### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录
|
||||
### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉
|
||||
for first_message in self.history_message[0].messages:
|
||||
text += first_message.type + ":" + first_message.content + self.sep
|
||||
if not isinstance(first_message, ViewMessage):
|
||||
text += first_message.type + ":" + first_message.content + self.prompt_template.sep
|
||||
|
||||
index = self.chat_retention_rounds - 1
|
||||
for last_message in self.history_message[-index:].messages:
|
||||
text += last_message.type + ":" + last_message.content + self.sep
|
||||
if not isinstance(last_message, ViewMessage):
|
||||
text += last_message.type + ":" + last_message.content + self.prompt_template.sep
|
||||
|
||||
else:
|
||||
### 直接历史记录拼接
|
||||
for conversation in self.history_message:
|
||||
for message in conversation.messages:
|
||||
text += message.type + ":" + message.content + self.sep
|
||||
if not isinstance(message, ViewMessage):
|
||||
text += message.type + ":" + message.content + self.prompt_template.sep
|
||||
|
||||
### current conversation
|
||||
for now_message in self.current_message.messages:
|
||||
text += now_message.type + ":" + now_message.content + self.sep
|
||||
text += now_message.type + ":" + now_message.content + self.prompt_template.sep
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@property
|
||||
@classmethod
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatWithDb.value
|
||||
|
||||
|
0
pilot/scene/chat_db/example.py
Normal file
0
pilot/scene/chat_db/example.py
Normal file
@@ -11,8 +11,9 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
import pandas as pd
|
||||
|
||||
from pilot.out_parser.base import BaseOutputParser
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
|
||||
|
||||
class SqlAction(NamedTuple):
|
||||
@@ -22,8 +23,15 @@ class SqlAction(NamedTuple):
|
||||
|
||||
class DbChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse(self, text: str) -> SqlAction:
|
||||
cleaned_output = text.rstrip()
|
||||
def __init__(self, sep:str, is_stream_out: bool):
|
||||
super().__init__(sep=sep, is_stream_out=is_stream_out )
|
||||
|
||||
|
||||
def parse_model_server_out(self, response) -> str:
|
||||
return super().parse_model_server_out(response)
|
||||
|
||||
def parse_prompt_response(self, model_out_text):
|
||||
cleaned_output = model_out_text.rstrip()
|
||||
if "```json" in cleaned_output:
|
||||
_, cleaned_output = cleaned_output.split("```json")
|
||||
if "```" in cleaned_output:
|
||||
@@ -40,6 +48,16 @@ class DbChatOutputParser(BaseOutputParser):
|
||||
|
||||
return SqlAction(sql, thoughts)
|
||||
|
||||
def parse_view_response(self, data) -> str:
|
||||
### tool out data to table view
|
||||
df = pd.DataFrame(data[1:], columns=data[0])
|
||||
table_style = """<style>
|
||||
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
|
||||
</style>"""
|
||||
html_table = df.to_html(index=False, escape=False)
|
||||
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
|
||||
return html.replace("\n", " ")
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "sql_chat"
|
||||
|
@@ -2,7 +2,8 @@ import json
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_db.out_parser import DbChatOutputParser
|
||||
from pilot.scene.chat_db.out_parser import DbChatOutputParser,SqlAction
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@@ -21,6 +22,15 @@ You can order the results by a relevant column to return the most interesting ex
|
||||
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
|
||||
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
||||
|
||||
"""
|
||||
|
||||
_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
|
||||
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
|
||||
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
|
||||
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
||||
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
|
||||
|
||||
|
||||
"""
|
||||
|
||||
PROMPT_RESPONSE = """You should only respond in JSON format as following format:
|
||||
@@ -37,17 +47,18 @@ RESPONSE_FORMAT = {
|
||||
"SQL": "SQL Query to run"
|
||||
}
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
PROMPT_NEED_NEED_STREAM_OUT = False
|
||||
|
||||
chat_db_prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatWithDb.value,
|
||||
input_variables=["input", "table_info", "dialect", "top_k", "response"],
|
||||
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
|
||||
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
|
||||
output_parser=DbChatOutputParser()
|
||||
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE + PROMPT_SUFFIX,
|
||||
output_parser=DbChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
|
||||
)
|
||||
|
||||
CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
resp = chat_db_prompt.format(input="查询用户信息", table_info="user(a,b,c,d)", dialect="mysql", top_k=10)
|
||||
print(resp)
|
||||
|
31
pilot/scene/chat_normal/prompt.py
Normal file
31
pilot/scene/chat_normal/prompt.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import builtins
|
||||
|
||||
|
||||
def stream_write_and_read(lst):
|
||||
# 对lst使用yield from进行可迭代对象的扁平化
|
||||
yield from lst
|
||||
while True:
|
||||
val = yield
|
||||
lst.append(val)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建一个空列表
|
||||
my_list = []
|
||||
|
||||
# 使用生成器写入数据
|
||||
stream_writer = stream_write_and_read(my_list)
|
||||
next(stream_writer)
|
||||
stream_writer.send(10)
|
||||
print(1)
|
||||
stream_writer.send(20)
|
||||
print(2)
|
||||
stream_writer.send(30)
|
||||
print(3)
|
||||
|
||||
# 使用生成器读取数据
|
||||
stream_reader = stream_write_and_read(my_list)
|
||||
next(stream_reader)
|
||||
print(stream_reader.send(None))
|
||||
print(stream_reader.send(None))
|
||||
print(stream_reader.send(None))
|
@@ -9,8 +9,7 @@ from typing import (
|
||||
List,
|
||||
)
|
||||
|
||||
from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, messages_to_dict, \
|
||||
messages_from_dict
|
||||
from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, ViewMessage, messages_to_dict, messages_from_dict
|
||||
|
||||
|
||||
class OnceConversation:
|
||||
@@ -27,12 +26,23 @@ class OnceConversation:
|
||||
|
||||
def add_user_message(self, message: str) -> None:
|
||||
"""Add a user message to the store"""
|
||||
has_message = any(isinstance(instance, HumanMessage) for instance in self.messages)
|
||||
if has_message:
|
||||
raise ValueError("Already Have Human message")
|
||||
self.messages.append(HumanMessage(content=message))
|
||||
|
||||
def add_ai_message(self, message: str) -> None:
|
||||
"""Add an AI message to the store"""
|
||||
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
|
||||
if has_message:
|
||||
raise ValueError("Already Have Ai message")
|
||||
self.messages.append(AIMessage(content=message))
|
||||
""" """
|
||||
def add_view_message(self, message: str) -> None:
|
||||
"""Add an AI message to the store"""
|
||||
|
||||
self.messages.append(ViewMessage(content=message))
|
||||
""" """
|
||||
|
||||
def add_system_message(self, message: str) -> None:
|
||||
"""Add an AI message to the store"""
|
||||
|
@@ -218,7 +218,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
}
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
|
||||
chat.call()
|
||||
# state.append_message(state.roles[1], chat.current_ai_response())
|
||||
state.messages[-1][-1] = f"{chat.current_ai_response()}"
|
||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
|
||||
|
Reference in New Issue
Block a user