lint: fix

csunny 2023-05-25 20:39:04 +08:00
parent 4040e1592a
commit e95696c23d
23 changed files with 277 additions and 173 deletions

View File

@@ -1,21 +1,21 @@
import markdown2
import pandas as pd
def datas_to_table_html(data):
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
return html.replace("\n", " ")
def generate_markdown_table(data):
"""\n 生成 Markdown 表格\n data: 一个包含表头和表格内容的二维列表\n """
"""\n 生成 Markdown 表格\n data: 一个包含表头和表格内容的二维列表\n"""
# Get the number of table columns
num_cols = len(data[0])
# Generate the header row
@@ -41,6 +41,7 @@ def generate_markdown_table(data):
return table
def generate_htm_table(data):
markdown_text = generate_markdown_table(data)
html_table = markdown2.markdown(markdown_text, extras=["tables"])
@@ -53,4 +54,4 @@ if __name__ == "__main__":
table_style = """<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid #ddd;\n padding: 8px;\n text-align: center;\n line-height: 150px; \n }\n th {\n background-color: #f2f2f2;\n color: #333;\n font-weight: bold;\n }\n tr:nth-child(even) {\n background-color: #f9f9f9;\n }\n tr:hover {\n background-color: #f2f2f2;\n }\n </style>"""
print(table_style.replace("\n", " "))
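A quick usage sketch of the helpers above; the sample rows are made up, and `markdown2` and `pandas` are assumed to be installed:

data = [
    ["name", "score"],  # header row first, the layout datas_to_table_html expects
    ["alice", 90],
    ["bob", 85],
]
print(generate_markdown_table(data))  # plain Markdown source
print(generate_htm_table(data))       # Markdown rendered to an HTML <table>
print(datas_to_table_html(data))      # pandas-rendered HTML with inline styles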

View File

@@ -1,8 +1,9 @@
from enum import auto, Enum
from typing import List, Any
class SeparatorStyle(Enum):
SINGLE = "###"
TWO = "</s>"
THREE = auto()
FOUR = auto()

View File

@@ -30,16 +30,16 @@ class Database:
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
@@ -119,7 +119,7 @@ class Database:
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> Database:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
@@ -148,7 +148,7 @@ class Database:
self._metadata = MetaData()
# sql = f"use {db_name}"
sql = text(f"use `{db_name}`")
session.execute(sql)
# Process the table metadata
@@ -159,13 +159,17 @@ class Database:
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=db_name)
+ (
self._inspector.get_view_names(schema=db_name)
if self.view_support
else []
)
)
return session
def get_current_db_name(self, session) -> str:
return session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self, session):
_sql = f"""
@@ -201,7 +205,7 @@ class Database:
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
@@ -214,7 +218,7 @@ class Database:
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
@@ -303,6 +307,10 @@ class Database:
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
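A hedged construction sketch based only on the `from_uri` classmethod and `get_database_list` shown above; the URI, credentials, and pool settings are placeholders:

from pilot.common.sql_database import Database

db = Database.from_uri(
    "mysql+pymysql://user:password@127.0.0.1:3306",  # placeholder credentials
    engine_args={"pool_size": 10, "pool_recycle": 3600},
)
print(db.get_database_list())  # user databases; system schemas are filtered out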

View File

@@ -47,7 +47,6 @@ class Config(metaclass=Singleton):
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
self.exit_key = os.getenv("EXIT_KEY", "n")
self.image_provider = os.getenv("IMAGE_PROVIDER", True)
@@ -104,8 +103,17 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
### TODO Adapt to multiple types of libraries
self.local_db = Database.from_uri("mysql+pymysql://" + self.LOCAL_DB_USER +":"+ self.LOCAL_DB_PASSWORD +"@" +self.LOCAL_DB_HOST + ":" + str(self.LOCAL_DB_PORT) ,
engine_args ={"pool_size": 10, "pool_recycle": 3600, "echo": True})
self.local_db = Database.from_uri(
"mysql+pymysql://"
+ self.LOCAL_DB_USER
+ ":"
+ self.LOCAL_DB_PASSWORD
+ "@"
+ self.LOCAL_DB_HOST
+ ":"
+ str(self.LOCAL_DB_PORT),
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")

View File

@@ -3,6 +3,7 @@
import pymysql
class MySQLOperator:
"""Connect MySQL Database fetch MetaData For LLM Prompt
Args:
@@ -31,8 +32,7 @@ class MySQLOperator:
results = cursor.fetchall()
return results
def run_sql(self, db_name: str, sql: str, fetch: str = "all"):
with self.conn.cursor() as cursor:
cursor.execute("USE " + db_name)
cursor.execute(sql)
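A usage sketch for `run_sql`. The constructor arguments are assumptions, since `__init__` is not part of this diff, and `fetch="all"` is assumed to return `cursor.fetchall()`:

op = MySQLOperator(  # hypothetical signature; __init__ is not shown here
    host="127.0.0.1", port=3306, user="root", password="aa123456"
)
rows = op.run_sql("mysql", "SELECT user, host FROM user LIMIT 3")
print(rows)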

View File

@@ -44,6 +44,7 @@ class Conversation:
skip_next: bool = False
conv_id: Any = None
last_user_input: Any = None
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
from typing import (
Any,
@@ -17,13 +17,9 @@ from typing import (
from pilot.scene.message import OnceConversation
class BaseChatHistoryMemory(ABC):
def __init__(self):
self.conversations: List[OnceConversation] = []
@abstractmethod
def messages(self) -> List[OnceConversation]: # type: ignore
@@ -33,8 +29,6 @@ class BaseChatHistoryMemory(ABC):
def append(self, message: OnceConversation) -> None:
"""Append the message to the record in the local file"""
@abstractmethod
def clear(self) -> None:
"""Clear session memory from the local file"""

View File

@@ -5,31 +5,33 @@ import datetime
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pathlib import Path
from pilot.configs.config import Config
from pilot.scene.message import (
OnceConversation,
conversation_from_dict,
conversations_to_dict,
)
CFG = Config()
class FileHistoryMemory(BaseChatHistoryMemory):
def __init__(self, chat_session_id: str):
now = datetime.datetime.now()
date_string = now.strftime("%Y%m%d")
path: str = f"{CFG.message_dir}/{date_string}"
os.makedirs(path, exist_ok=True)
dir_path = Path(path)
self.file_path = Path(dir_path / f"{chat_session_id}.json")
if not self.file_path.exists():
self.file_path.touch()
self.file_path.write_text(json.dumps([]))
def messages(self) -> List[OnceConversation]:
items = json.loads(self.file_path.read_text())
history: List[OnceConversation] = []
for once in items:
messages = conversation_from_dict(once)
history.append(messages)
@@ -38,8 +40,10 @@ class FileHistoryMemory(BaseChatHistoryMemory):
def append(self, once_message: OnceConversation) -> None:
history = self.messages()
history.append(once_message)
self.file_path.write_text(
json.dumps(conversations_to_dict(history), ensure_ascii=False, indent=4),
encoding="UTF-8",
)
def clear(self) -> None:
self.file_path.write_text(json.dumps([]))
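A round-trip sketch of this memory class, assuming `CFG.message_dir` points at a writable directory; the session id is made up:

import datetime

from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.scene.message import OnceConversation

memory = FileHistoryMemory("demo-session-001")  # hypothetical session id

conversation = OnceConversation()
conversation.set_start_time(datetime.datetime.now())
conversation.add_user_message("list the top ten users")
conversation.add_ai_message("SELECT * FROM users LIMIT 10")

memory.append(conversation)    # persisted as JSON under CFG.message_dir
print(len(memory.messages()))  # the same session reloaded from disk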

View File

@@ -41,16 +41,16 @@ class BaseOutputParser(ABC):
text = text.lower()
respObj = json.loads(text)
xx = respObj["response"]
xx = xx.strip(b"\x00".decode())
respObj_ex = json.loads(xx)
if respObj_ex["error_code"] == 0:
all_text = respObj_ex["text"]
### Parse the returned text to extract the AI reply
tmpResp = all_text.split(sep)
last_index = -1
for i in range(len(tmpResp)):
if tmpResp[i].find("assistant:") != -1:
last_index = i
ai_response = tmpResp[last_index]
ai_response = ai_response.replace("assistant:", "")
@@ -60,9 +60,7 @@ class BaseOutputParser(ABC):
print("un_stream clear response:{}", ai_response)
return ai_response
else:
raise ValueError("Model server error!code=" + respObj_ex['error_code']);
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def parse_model_server_out(self, response) -> str:
"""

View File

@@ -1,5 +1,3 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
@@ -8,7 +6,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
def get_buffer_string(
@@ -29,7 +27,6 @@ def get_buffer_string(
return "\n".join(string_messages)
class PromptValue(BaseModel, ABC):
@abstractmethod
def to_string(self) -> str:
@@ -39,6 +36,7 @@ class PromptValue(BaseModel, ABC):
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
class ChatPromptValue(PromptValue):
messages: List[BaseMessage]
@@ -48,4 +46,4 @@ class ChatPromptValue(PromptValue):
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return self.messages

View File

@@ -3,8 +3,8 @@ from typing import Any, Callable, Dict, List, Optional
class PromptGenerator:
"""
Generates custom prompt strings based on constraints.
Compatible with the AutoGpt plugin.
"""
def __init__(self) -> None:
@@ -22,8 +22,6 @@ class PromptGenerator:
self.role = "AI"
self.response_format = None
def add_command(
self,
command_label: str,
@@ -51,4 +49,4 @@ class PromptGenerator:
"args": command_args,
"function": function,
}
self.commands.append(command)

View File

@@ -8,6 +8,7 @@ from pilot.common.formatting import formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:
@@ -32,22 +33,23 @@ class PromptTemplate(BaseModel, ABC):
"""A list of the names of the variables the prompt template expects."""
template_scene: str
template_define: str
"""The definition for this template."""
template: str
"""The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format: str
"""default use stream out"""
stream_out: bool = True
""""""
output_parser: BaseOutputParser = None
""""""
sep: str = SeparatorStyle.SINGLE.value
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@@ -96,10 +98,8 @@ class PromptTemplate(BaseModel, ABC):
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
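For the default "f-string" path, `format()` reduces to injecting `response_format` as pretty-printed JSON and then formatting the template; a minimal illustration with made-up template text and field names:

import json

template = "Respond in this JSON format:\n{response}\nQuestion: {input}"
response_format = {"thoughts": {"speak": "summary for the user"}, "sql": "SQL Query to run"}

kwargs = {"input": "How many orders were placed today?"}
kwargs["response"] = json.dumps(response_format, indent=4)
print(template.format(**kwargs))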

View File

@@ -207,6 +207,7 @@ class BasePromptTemplate(BaseModel, ABC):
else:
raise ValueError(f"{save_path} must be json or yaml")
class StringPromptValue(PromptValue):
text: str
@@ -219,7 +220,6 @@ class StringPromptValue(PromptValue):
return [HumanMessage(content=self.text)]
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt should expose the format method, returning a prompt."""
@@ -360,4 +360,4 @@ class PromptTemplate(StringPromptTemplate):
# For backwards compatibility.
Prompt = PromptTemplate

View File

@@ -1,8 +1,9 @@
from enum import Enum
class ChatScene(Enum):
ChatWithDb = "chat_with_db"
ChatExecution = "chat_execution"
ChatKnowledge = "chat_default_knowledge"
ChatNewKnowledge = "chat_new_knowledge"
ChatNormal = "chat_normal"

View File

@@ -20,7 +20,7 @@ from pilot.prompts.prompt_new import PromptTemplate
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
@@ -32,8 +32,10 @@ from pilot.configs.config import Config
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
CFG = Config()
class BaseChat(ABC):
chat_scene: str = None
llm_model: Any = None
temperature: float = 0.6
max_new_tokens: int = 1024
@@ -42,17 +44,20 @@ class BaseChat( ABC):
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def __init__(self, chat_mode, chat_session_id, current_user_input):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
### TODO
self.memory = FileHistoryMemory(chat_session_id)
### load prompt template
self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value
]
self.history_message: List[OnceConversation] = []
self.current_message: OnceConversation = OnceConversation()
self.current_tokens_used: int = 0
@@ -69,7 +74,6 @@ class BaseChat( ABC):
def chat_type(self) -> str:
raise NotImplementedError("Not supported for this chat type.")
def call(self):
pass
@@ -88,7 +92,7 @@ class BaseChat( ABC):
"""
return self.memory.messages()
def generate(self, p) -> str:
"""
generate context for LLM input
Args:

View File

@@ -15,6 +15,7 @@ from typing import (
from pydantic import BaseModel, Extra, Field, root_validator
class PromptValue(BaseModel, ABC):
@abstractmethod
def to_string(self) -> str:
@@ -37,7 +38,6 @@ class BaseMessage(BaseModel):
"""Type of the message, used for serialization."""
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
@@ -49,7 +49,6 @@ class HumanMessage(BaseMessage):
return "human"
class AIMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
@@ -81,8 +80,6 @@ class SystemMessage(BaseMessage):
return "system"
class Generation(BaseModel):
"""Output of a single generation."""
@@ -94,7 +91,6 @@ class Generation(BaseModel):
"""May include things like reason for finishing (e.g. in OpenAI)"""
class ChatGeneration(Generation):
"""Output of a single generation."""
@@ -126,7 +122,6 @@ class LLMResult(BaseModel):
"""For arbitrary LLM provider specific output."""
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
@@ -149,6 +144,5 @@ def _message_from_dict(message: dict) -> BaseMessage:
raise ValueError(f"Got unexpected type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages]

View File

@@ -14,7 +14,13 @@ from sqlalchemy import (
)
from typing import Any, Iterable, List, Optional
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
HumanMessage,
AIMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@@ -25,7 +31,11 @@ from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_db.prompt import chat_db_prompt
from pilot.out_parser.base import BaseOutputParser
from pilot.scene.chat_db.out_parser import DbChatOutputParser
@@ -39,8 +49,7 @@ class ChatWithDb(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, db_name, user_input):
"""
"""
""" """
super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
@@ -71,7 +80,9 @@ class ChatWithDb(BaseChat):
### Build the current conversation: decide whether the prompt is constructed as a first conversation, and whether switching databases needs to be handled
if self.history_message:
## TODO: decide how to handle a database switch when conversation history is present
logger.info(f"There are already {len(self.history_message)} rounds of conversations!")
logger.info(
f"There are already {len(self.history_message)} rounds of conversations!"
)
self.current_message.add_system_message(current_prompt)
@@ -87,32 +98,56 @@ class ChatWithDb(BaseChat):
try:
### Call the non-streaming model service endpoint
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120)
ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
ai_response_text = (
self.prompt_template.output_parser.parse_model_server_out(response)
)
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.database.run(self.db_connect, prompt_define_response.sql)
if hasattr(prompt_define_response, "thoughts"):
if prompt_define_response.thoughts.get("speak"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("speak"),result))
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("speak"), result
)
)
elif prompt_define_response.thoughts.get("reasoning"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("reasoning"), result))
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("reasoning"), result
)
)
else:
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts, result
)
)
else:
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response, result
)
)
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
self.current_message.add_view_message(f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """)
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### Persist the conversation record
self.memory.append(self.current_message)
@@ -121,10 +156,10 @@ class ChatWithDb(BaseChat):
# A single round of conversation can only have one user record and one AI record
# TODO: show the reasoning process in the front end...
for message in self.current_message.messages:
if isinstance(message, HumanMessage):
ret[-1][-2] = message.content
# Whether to show the reasoning process
if isinstance(message, ViewMessage):
ret[-1][-1] = message.content
return ret
@@ -132,34 +167,51 @@ class ChatWithDb(BaseChat):
# Temporary, for front-end compatibility
def current_ai_response(self) -> str:
for message in self.current_message.messages:
if message.type == "view":
return message.content
return None
def generate_llm_text(self) -> str:
text = self.prompt_template.template_define + self.prompt_template.sep
### Process the history messages first
if len(self.history_message) > self.chat_retention_rounds:
### Merge the first and last rounds of the history into the conversation record; messages shown to the user (view messages) must be filtered out when building the context prompt
for first_message in self.history_message[0].messages:
if not isinstance(first_message, ViewMessage):
text += (
first_message.type
+ ":"
+ first_message.content
+ self.prompt_template.sep
)
index = self.chat_retention_rounds - 1
for conversation in self.history_message[-index:]:
for last_message in conversation.messages:
if not isinstance(last_message, ViewMessage):
text += (
last_message.type
+ ":"
+ last_message.content
+ self.prompt_template.sep
)
else:
### Otherwise, concatenate the history records directly
for conversation in self.history_message:
for message in conversation.messages:
if not isinstance(message, ViewMessage):
text += (
message.type
+ ":"
+ message.content
+ self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
now_message.type + ":" + now_message.content + self.prompt_template.sep
)
return text
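The retention rule above keeps the first round plus the most recent `chat_retention_rounds - 1` rounds once the history grows past the limit (view messages are filtered separately); a standalone sketch of just the round selection:

def retained_rounds(history: list, chat_retention_rounds: int) -> list:
    # Keep round 0 plus the last (chat_retention_rounds - 1) rounds.
    if len(history) > chat_retention_rounds:
        return [history[0]] + history[-(chat_retention_rounds - 1):]
    return history

print(retained_rounds(list(range(10)), 3))  # -> [0, 8, 9]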

View File

@@ -1,25 +1,24 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
class SqlAction(NamedTuple):
sql: str
thoughts: Dict
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class DbChatOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out)
def parse_model_server_out(self, response) -> str:
return super().parse_model_server_out(response)
@@ -31,20 +30,20 @@ class DbChatOutputParser(BaseOutputParser):
if "```" in cleaned_output:
cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):]
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```"):]
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing")
json_pattern = r"{(.+?)}"
m = re.search(json_pattern, cleaned_output)
if m:
cleaned_output = m.group(0)
else:
raise ValueError("model server out not fllow the prompt!")
raise ValueError("model server out not fllow the prompt!")
response = json.loads(cleaned_output)
sql, thoughts = response["sql"], response["thoughts"]

View File

@@ -45,7 +45,7 @@ RESPONSE_FORMAT = {
"reasoning": "reasoning",
"speak": "thoughts summary to say to user",
},
"sql": "SQL Query to run"
"sql": "SQL Query to run",
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -59,7 +59,9 @@ chat_db_prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=DbChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})

View File

@@ -4,8 +4,10 @@ from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.message import OnceConversation
from pilot.scene.base import ChatScene
class ChatWithPlugin(BaseChat):
chat_scene: str = ChatScene.ChatExecution.value
def __init__(self, chat_mode, chat_session_id, current_user_input):
super().__init__(chat_mode, chat_session_id, current_user_input)
@@ -23,4 +25,4 @@ class ChatWithPlugin(BaseChat):
@property
def chat_type(self) -> str:
return ChatScene.ChatExecution.value

View File

@@ -1,19 +1,17 @@
from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton
from pilot.scene.chat_db.chat import ChatWithDb
from pilot.scene.chat_execution.chat import ChatWithPlugin
class ChatFactory(metaclass=Singleton):
@staticmethod
def get_implementation(chat_mode, **kwargs):
chat_classes = BaseChat.__subclasses__()
implementation = None
for cls in chat_classes:
if cls.chat_scene == chat_mode:
implementation = cls(**kwargs)
if implementation is None:
raise Exception("Invalid implementation name:" + chat_mode)
return implementation
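The factory resolves a scene by scanning `BaseChat.__subclasses__()` for a matching `chat_scene` and forwards `**kwargs` to that class's constructor. A hypothetical call mirroring the `chat_param` dict the web server builds below:

chat = ChatFactory.get_implementation(
    "chat_with_db",              # ChatScene.ChatWithDb.value
    chat_session_id="demo-001",  # forwarded to ChatWithDb.__init__
    db_name="mysql",
    user_input="how many tables are there?",
)
chat.call()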

View File

@@ -9,12 +9,20 @@ from typing import (
List,
)
from pilot.scene.base_message import (
BaseMessage,
AIMessage,
HumanMessage,
SystemMessage,
ViewMessage,
messages_to_dict,
messages_from_dict,
)
class OnceConversation:
"""
All the information of one conversation. Currently held in memory by a single service; cache and database support can be added for distributed services
"""
def __init__(self):
@@ -26,7 +34,9 @@ class OnceConversation:
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
has_message = any(
isinstance(instance, HumanMessage) for instance in self.messages
)
if has_message:
raise ValueError("Already Have Human message")
self.messages.append(HumanMessage(content=message))
@@ -38,6 +48,7 @@ class OnceConversation:
raise ValueError("Already Have Ai message")
self.messages.append(AIMessage(content=message))
""" """
def add_view_message(self, message: str) -> None:
"""Add an AI message to the store"""
@@ -50,7 +61,7 @@ class OnceConversation:
def set_start_time(self, dt: datetime):
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
self.start_date = dt_str
def clear(self) -> None:
"""Remove all messages from the store"""
@@ -71,7 +82,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
"start_date": start_str,
"cost": once.cost if once.cost else 0,
"tokens": once.tokens if once.tokens else 0,
"messages": messages_to_dict(once.messages)
"messages": messages_to_dict(once.messages),
}
@@ -81,10 +92,10 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
def conversation_from_dict(once: dict) -> OnceConversation:
conversation = OnceConversation()
conversation.cost = once.get("cost", 0)
conversation.tokens = once.get("tokens", 0)
conversation.start_date = once.get("start_date", "")
conversation.chat_order = int(once.get("chat_order"))
print(once.get("messages"))
conversation.messages = messages_from_dict(once.get("messages", []))
return conversation
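A serialization round-trip sketch; `chat_order` and the start time are set explicitly here because `conversation_from_dict` casts `chat_order` with `int()`:

import datetime

from pilot.scene.message import (
    OnceConversation,
    conversation_from_dict,
    conversations_to_dict,
)

once = OnceConversation()
once.chat_order = 1  # conversation_from_dict does int(once.get("chat_order"))
once.set_start_time(datetime.datetime.now())
once.add_user_message("hello")
once.add_ai_message("hi there")

payload = conversations_to_dict([once])        # JSON-safe list of dicts
restored = conversation_from_dict(payload[0])  # rebuild a single conversation
print([m.content for m in restored.messages])  # ['hello', 'hi there']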

View File

@@ -137,13 +137,15 @@ def load_demo(url_params, request: gr.Request):
unique_id = uuid.uuid1()
state.conv_id = str(unique_id)
return (
state,
dropdown_update,
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True),
)
def get_conv_log_filename():
@@ -203,30 +205,31 @@ def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
return ChatScene.ChatExecution
else:
return ChatScene.ChatNormal
def http_bot(
state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
):
logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
start_tstamp = time.time()
scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
model_name = CFG.LLM_MODEL
if ChatScene.ChatWithDb == scene:
logger.info("基于DB对话走新的模式")
chat_param = {
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
# TODO: the request here should incorporate the existing knowledge base so answers are grounded in it, so the prompt needs further optimization
if state.skip_next:
@@ -242,7 +245,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
# Add a context hint to the prompt so the chat is grounded in existing knowledge. Should the context hint go only in the first round, or be added in every round?
# If the user's questions span a wide range, the hint should be added in every round.
if db_selector:
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
@@ -251,7 +256,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
new_state.conv_id = uuid.uuid4().hex
state = new_state
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["default_knownledge"] and not db_selector:
@@ -263,16 +267,25 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
persist_dir = os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
)
print("向量数据库持久化地址: ", persist_dir)
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["sentence-transforms"],
local_persist=False,
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, 1)
context = [d.page_content for d in docs]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
input_variables=["context", "question"],
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
@@ -285,7 +298,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
"prompt": prompt,
"temperature": float(temperature),
"max_new_tokens": int(max_new_tokens),
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
"stop": state.sep
if state.sep_style == SeparatorStyle.SINGLE
else state.sep2,
}
logger.info(f"Requert: \n{payload}")
@@ -295,8 +310,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
try:
# Stream output
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers, json=payload, stream=True, timeout=20)
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
@@ -309,12 +329,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
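The streaming loop above assumes a NUL-delimited protocol: the model server emits JSON chunks shaped like {"text": ..., "error_code": 0} separated by b"\0". A minimal standalone consumer; the server URL and payload are placeholders:

import json
from urllib.parse import urljoin

import requests

response = requests.post(
    urljoin("http://127.0.0.1:8000", "generate_stream"),  # placeholder MODEL_SERVER
    headers={"User-Agent": "dbgpt Client"},
    json={"prompt": "hello", "temperature": 0.7, "max_new_tokens": 64, "stop": "###"},
    stream=True,
    timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        print(data["text"], f"(error_code: {data['error_code']})")
        break
    print(data["text"])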
@@ -573,7 +604,6 @@ def knowledge_embedding_store(vs_id, files):
)
knowledge_embedding_client.knowledge_embedding()
logger.info("knowledge embedding success")
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")