Merge branch 'dev' into llm_fxp

This commit is contained in:
csunny
2023-06-01 14:39:33 +08:00
66 changed files with 1780 additions and 1336 deletions

View File

@@ -55,54 +55,6 @@ def fix_and_parse_json(
logger.error("参数解析错误", e)
def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
"""Fix the given JSON string to make it parseable and fully compliant with two techniques.
Args:
json_string (str): The JSON string to fix.
Returns:
str: The fixed JSON string.
"""
assistant_reply = assistant_reply.strip()
if assistant_reply.startswith("```json"):
assistant_reply = assistant_reply[7:]
if assistant_reply.endswith("```"):
assistant_reply = assistant_reply[:-3]
try:
return json.loads(assistant_reply) # just check the validity
except json.JSONDecodeError as e: # noqa: E722
print(f"JSONDecodeError: {e}")
pass
if assistant_reply.startswith("json "):
assistant_reply = assistant_reply[5:]
assistant_reply = assistant_reply.strip()
try:
return json.loads(assistant_reply) # just check the validity
except json.JSONDecodeError: # noqa: E722
pass
# Parse and print Assistant response
assistant_reply_json = fix_and_parse_json(assistant_reply)
logger.debug("Assistant reply JSON: %s", str(assistant_reply_json))
if assistant_reply_json == {}:
assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets(
assistant_reply
)
logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json))
if assistant_reply_json != {}:
return assistant_reply_json
logger.error(
"Error: The following AI output couldn't be converted to a JSON:\n",
assistant_reply,
)
if CFG.speak_mode:
say_text("I have received an invalid JSON response from the OpenAI API.")
return {}
def correct_json(json_to_load: str) -> str:
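A minimal usage sketch of the multi-technique fixer above (the reply string is hypothetical):

from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques

# A typical model reply wrapped in a markdown code fence
reply = '```json\n{"thoughts": "list users", "sql": "SELECT * FROM users LIMIT 5"}\n```'
parsed = fix_json_using_multiple_techniques(reply)
print(parsed["sql"])  # SELECT * FROM users LIMIT 5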

View File

@@ -4,10 +4,9 @@
import json
from typing import Dict
from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
from pilot.commands.exception_not_commands import NotCommands
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from pilot.prompts.generator import PluginPromptGenerator
from pilot.speech import say_text
@@ -24,8 +23,8 @@ def _resolve_pathlike_command_args(command_args):
def execute_ai_response_json(
prompt: PromptGenerator,
ai_response: str,
prompt: PluginPromptGenerator,
ai_response,
user_input: str = None,
) -> str:
"""
@@ -39,11 +38,8 @@ def execute_ai_response_json(
"""
cfg = Config()
try:
assistant_reply_json = fix_json_using_multiple_techniques(ai_response)
except (json.JSONDecodeError, ValueError, AttributeError) as e:
raise NotCommands("非可执行命令结构")
command_name, arguments = get_command(assistant_reply_json)
command_name, arguments = get_command(ai_response)
if cfg.speak_mode:
say_text(f"I want to execute {command_name}")
@@ -71,7 +67,7 @@ def execute_ai_response_json(
def execute_command(
command_name: str,
arguments,
prompt: PromptGenerator,
prompt: PluginPromptGenerator,
):
"""Execute the command and return the result

View File

@@ -1,29 +0,0 @@
from typing import Optional
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from pilot.prompts.prompt import build_default_prompt_generator
class CommandsLoad:
"""
Load plugin command info to help build the system prompt.
"""
def __init__(self) -> None:
self.command_registry = None
def getCommandInfos(
self, prompt_generator: Optional[PromptGenerator] = None
) -> str:
cfg = Config()
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
for plugin in cfg.plugins:
if not plugin.can_handle_post_prompt():
continue
prompt_generator = plugin.post_prompt(prompt_generator)
self.prompt_generator = prompt_generator
command_infos = ""
command_infos += f"\n\n{prompt_generator.commands()}"
return command_infos

View File

@@ -277,6 +277,7 @@ class Database:
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("sql run:" + command)
cursor = session.execute(text(command))
if cursor.returns_rows:
if fetch == "all":

View File

@@ -1,167 +0,0 @@
# sourcery skip: do-not-use-staticmethod
"""
A module that contains the AIConfig class object that contains the configuration
"""
from __future__ import annotations
import os
import platform
from pathlib import Path
from typing import Optional
import distro
import yaml
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from pilot.prompts.prompt import build_default_prompt_generator
# Soon this will go in a folder where it remembers more stuff about the run(s)
SAVE_FILE = str(Path(os.getcwd()) / "ai_settings.yaml")
class AIConfig:
"""
A class object that contains the configuration information for the AI
Attributes:
ai_name (str): The name of the AI.
ai_role (str): The description of the AI's role.
ai_goals (list): The list of objectives the AI is supposed to complete.
api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
"""
def __init__(
self,
ai_name: str = "",
ai_role: str = "",
ai_goals: list | None = None,
api_budget: float = 0.0,
) -> None:
"""
Initialize a class instance
Parameters:
ai_name (str): The name of the AI.
ai_role (str): The description of the AI's role.
ai_goals (list): The list of objectives the AI is supposed to complete.
api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
Returns:
None
"""
if ai_goals is None:
ai_goals = []
self.ai_name = ai_name
self.ai_role = ai_role
self.ai_goals = ai_goals
self.api_budget = api_budget
self.prompt_generator = None
self.command_registry = None
@staticmethod
def load(config_file: str = SAVE_FILE) -> "AIConfig":
"""
Returns a class object with parameters (ai_name, ai_role, ai_goals, api_budget) loaded from
the yaml file if it exists,
else returns an instance with no parameters.
Parameters:
config_file (str): The path to the config yaml file.
DEFAULT: "../ai_settings.yaml"
Returns:
cls (object): An instance of given cls object
"""
try:
with open(config_file, encoding="utf-8") as file:
config_params = yaml.load(file, Loader=yaml.FullLoader)
except FileNotFoundError:
config_params = {}
ai_name = config_params.get("ai_name", "")
ai_role = config_params.get("ai_role", "")
ai_goals = [
str(goal).strip("{}").replace("'", "").replace('"', "")
if isinstance(goal, dict)
else str(goal)
for goal in config_params.get("ai_goals", [])
]
api_budget = config_params.get("api_budget", 0.0)
# type is Type[AIConfig]
return AIConfig(ai_name, ai_role, ai_goals, api_budget)
def save(self, config_file: str = SAVE_FILE) -> None:
"""
Saves the class parameters to the specified yaml file path.
Parameters:
config_file(str): The path to the config yaml file.
DEFAULT: "../ai_settings.yaml"
Returns:
None
"""
config = {
"ai_name": self.ai_name,
"ai_role": self.ai_role,
"ai_goals": self.ai_goals,
"api_budget": self.api_budget,
}
with open(config_file, "w", encoding="utf-8") as file:
yaml.dump(config, file, allow_unicode=True)
def construct_full_prompt(
self, prompt_generator: Optional[PromptGenerator] = None
) -> str:
"""
Returns a prompt to the user with the class information in an organized fashion.
Parameters:
None
Returns:
full_prompt (str): A string containing the initial prompt for the user
including the ai_name, ai_role, ai_goals, and api_budget.
"""
prompt_start = (
"Your decisions must always be made independently without"
" seeking user assistance. Play to your strengths as an LLM and pursue"
" simple strategies with no legal complications."
""
)
cfg = Config()
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
prompt_generator.goals = self.ai_goals
prompt_generator.name = self.ai_name
prompt_generator.role = self.ai_role
prompt_generator.command_registry = self.command_registry
for plugin in cfg.plugins:
if not plugin.can_handle_post_prompt():
continue
prompt_generator = plugin.post_prompt(prompt_generator)
if cfg.execute_local_commands:
# add OS info to prompt
os_name = platform.system()
os_info = (
platform.platform(terse=True)
if os_name != "Linux"
else distro.name(pretty=True)
)
prompt_start += f"\nThe OS you are running on is: {os_info}"
# Construct full prompt
full_prompt = f"You are {prompt_generator.name}, {prompt_generator.role}\n{prompt_start}\n\nGOALS:\n\n"
for i, goal in enumerate(self.ai_goals):
full_prompt += f"{i+1}. {goal}\n"
if self.api_budget > 0.0:
full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
self.prompt_generator = prompt_generator
full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
return full_prompt

View File

@@ -2,7 +2,34 @@
# -*- coding:utf-8 -*-
"""We need to design a base class. That other connector can Write with this"""
from abc import ABC, abstractmethod
from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional
class BaseConnection:
pass
class BaseConnect(BaseModel, ABC):
type: str
driver: str
def get_session(self, db_name: str):
pass
def get_table_names(self) -> Iterable[str]:
pass
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
pass
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
pass
def get_database_list(self):
pass
def run(self, session, command: str, fetch: str = "all") -> List:
pass

View File

@@ -1,64 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pymysql
class MySQLOperator:
"""Connect MySQL Database fetch MetaData For LLM Prompt
Args:
Usage:
"""
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
def __init__(self, user, password, host="localhost", port=3306) -> None:
self.conn = pymysql.connect(
host=host,
user=user,
port=port,
passwd=password,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
)
def get_schema(self, schema_name):
with self.conn.cursor() as cursor:
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{schema_name}" group by TABLE_NAME;
"""
cursor.execute(_sql)
results = cursor.fetchall()
return results
def run_sql(self, db_name: str, sql: str, fetch: str = "all"):
with self.conn.cursor() as cursor:
cursor.execute("USE " + db_name)
cursor.execute(sql)
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
return str(result)
def get_index(self, schema_name):
pass
def get_db_list(self):
with self.conn.cursor() as cursor:
_sql = """
show databases;
"""
cursor.execute(_sql)
results = cursor.fetchall()
dbs = [
d["Database"] for d in results if d["Database"] not in self.default_db
]
return dbs
def get_meta(self, schema_name):
pass

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pymysql
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MySQLConnect(RDBMSDatabase):
"""Connect MySQL Database fetch MetaData For LLM Prompt
Args:
Usage:
"""
type:str = "MySQL"
connect_url = "mysql+pymysql://"
default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@@ -0,0 +1,318 @@
from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
import sqlalchemy
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import CursorResult, Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
from pilot.connections.base import BaseConnect
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
f' Columns: {str(index["column_names"])}'
)
class RDBMSDatabase(BaseConnect):
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")
self._inspector = inspect(self._engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
self._db_sessions = Session
self._all_tables = set()
self.view_support = view_support
self._usable_tables = set()
self._include_tables = set()
self._ignore_tables = set()
self._custom_table_info = set()
self._indexes_in_table_info = set()
self._sample_rows_in_table_info = set()
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
# self._all_tables = set(
# self._inspector.get_table_names(schema=schema)
# + (self._inspector.get_view_names(schema=schema) if view_support else [])
# )
# self._include_tables = set(include_tables) if include_tables else set()
# if self._include_tables:
# missing_tables = self._include_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"include_tables {missing_tables} not found in database"
# )
# self._ignore_tables = set(ignore_tables) if ignore_tables else set()
# if self._ignore_tables:
# missing_tables = self._ignore_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"ignore_tables {missing_tables} not found in database"
# )
# usable_tables = self.get_usable_table_names()
# self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
# if not isinstance(sample_rows_in_table_info, int):
# raise TypeError("sample_rows_in_table_info must be an integer")
#
# self._sample_rows_in_table_info = sample_rows_in_table_info
# self._indexes_in_table_info = indexes_in_table_info
#
# self._custom_table_info = custom_table_info
# if self._custom_table_info:
# if not isinstance(self._custom_table_info, dict):
# raise TypeError(
# "table_info must be a dictionary with table names as keys and the "
# "desired table info as values"
# )
# # only keep the tables that are also present in the database
# intersection = set(self._custom_table_info).intersection(self._all_tables)
# self._custom_table_info = dict(
# (table, self._custom_table_info[table])
# for table in self._custom_table_info
# if table in intersection
# )
# self._metadata = metadata or MetaData()
# # # including view support if view_support = true
# self._metadata.reflect(
# views=view_support,
# bind=self._engine,
# only=list(self._usable_tables),
# schema=self._schema,
# )
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
return cls(create_engine(database_uri, **_engine_args), **kwargs)
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return self._all_tables - self._ignore_tables
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
warnings.warn(
"This method is deprecated - please use `get_usable_table_names`."
)
return self.get_usable_table_names()
def get_session(self, db_name: str):
session = self._db_sessions()
self._metadata = MetaData()
# sql = f"use {db_name}"
sql = text(f"use `{db_name}`")
session.execute(sql)
# Reflect table metadata for the target schema
self._metadata.reflect(bind=self._engine, schema=db_name)
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=db_name)
+ (
self._inspector.get_view_names(schema=db_name)
if self.view_support
else []
)
)
return session
def get_current_db_name(self, session) -> str:
return session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self, session):
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
"""
cursor = session.execute(text(_sql))
results = cursor.fetchall()
return results
@property
def table_info(self) -> str:
"""Information about all tables in the database."""
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated in the paper.
"""
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
meta_tables = [
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
for table in meta_tables:
if self._custom_table_info and table.name in self._custom_table_info:
tables.append(self._custom_table_info[table.name])
continue
# add create table command
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
if self._indexes_in_table_info:
table_info += f"\n{self._get_table_indexes(table)}\n"
if self._sample_rows_in_table_info:
table_info += f"\n{self._get_sample_rows(table)}\n"
if has_extra_info:
table_info += "*/"
tables.append(table_info)
final_str = "\n\n".join(tables)
return final_str
def _get_sample_rows(self, table: Table) -> str:
# build the select command
command = select(table).limit(self._sample_rows_in_table_info)
# save the columns in string format
columns_str = "\t".join([col.name for col in table.columns])
try:
# get the sample rows
with self._engine.connect() as connection:
sample_rows_result: CursorResult = connection.execute(command)
# shorten values in the sample rows
sample_rows = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
)
# save the sample rows in string format
sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
# in some dialects when there are no rows in the table a
# 'ProgrammingError' is returned
except ProgrammingError:
sample_rows_str = ""
return (
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
f"{columns_str}\n"
f"{sample_rows_str}"
)
def _get_table_indexes(self, table: Table) -> str:
indexes = self._inspector.get_indexes(table.name)
indexes_formatted = "\n".join(map(_format_index, indexes))
return f"Table Indexes:\n{indexes_formatted}"
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
try:
return self.get_table_info(table_names)
except ValueError as e:
"""Format the error message"""
return f"Error: {e}"
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
cursor = session.execute(text(command))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = tuple(cursor.keys())
result = list(result)
result.insert(0, field_names)
return result
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
If the statement throws an error, the error message is returned.
"""
try:
return self.run(session, command, fetch)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
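A minimal sketch of the intended flow (the URI and database name are placeholders):

from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase

db = RDBMSDatabase.from_uri("mysql+pymysql://user:password@127.0.0.1:3306/mysql")
session = db.get_session("dbgpt")       # switch schema and reflect its tables
print(db.get_current_db_name(session))  # dbgpt
# run() prepends the tuple of column names to the fetched rows
rows = db.run(session, "SELECT 1 AS one")
print(rows)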

View File

@@ -105,18 +105,14 @@ class Conversation:
}
def gen_sqlgen_conversation(dbname):
from pilot.connections.mysql import MySQLOperator
mo = MySQLOperator(**(DB_SETTINGS))
message = ""
schemas = mo.get_schema(dbname)
for s in schemas:
message += s["schema_info"] + ";"
return f"Database {dbname} Schema information as follows: {message}\n"
conv_default = Conversation(
system=None,
roles=("human", "ai"),
messages=[],
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_one_shot = Conversation(
system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
@@ -261,7 +257,15 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
# question:
# {question}
# """
default_conversation = conv_one_shot
default_conversation = conv_default
chat_mode_title = {
"sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
"chat_use_plugin": get_lang_text("chat_use_plugin"),
"knowledge_qa": get_lang_text("knowledge_qa"),
}
conversation_sql_mode = {
"auto_execute_ai_response": get_lang_text("sql_generate_mode_direct"),
@@ -274,15 +278,11 @@ conversation_types = {
"knowledge_qa_type_default_knowledge_base_dialogue"
),
"custome": get_lang_text("knowledge_qa_type_add_knowledge_base_dialogue"),
"auto_execute_plugin": get_lang_text("dialogue_use_plugin"),
"url": get_lang_text("knowledge_qa_type_url_knowledge_dialogue"),
}
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1,
"auto_dbgpt_one_shot": auto_dbgpt_one_shot,
}
if __name__ == "__main__":
message = gen_sqlgen_conversation("dbgpt")
print(message)

View File

@@ -14,17 +14,22 @@ lang_dicts = {
"knowledge_qa_type_llm_native_dialogue": "LLM原生对话",
"knowledge_qa_type_default_knowledge_base_dialogue": "默认知识库对话",
"knowledge_qa_type_add_knowledge_base_dialogue": "新增知识库对话",
"dialogue_use_plugin": "对话使用插件",
"knowledge_qa_type_url_knowledge_dialogue": "URL网页知识对话",
"create_knowledge_base": "新建知识库",
"sql_schema_info": "数据库{}的Schema信息如下: {}\n",
"current_dialogue_mode": "当前对话模式",
"database_smart_assistant": "数据库智能助手",
"sql_vs_setting": "自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力",
"knowledge_qa": "知识问答",
"chat_use_plugin": "插件模式",
"dialogue_use_plugin": "对话使用插件",
"select_plugin": "选择插件",
"configure_knowledge_base": "配置知识库",
"new_klg_name": "新知识库名称",
"url_input_label": "输入网页地址",
"add_as_new_klg": "添加为新知识库",
"add_file_to_klg": "向知识库中添加文件",
"upload_file": "上传文件",
"add_file": "添加文件",
"upload_and_load_to_klg": "上传并加载到知识库",
@@ -47,14 +52,18 @@ lang_dicts = {
"knowledge_qa_type_llm_native_dialogue": "LLM native dialogue",
"knowledge_qa_type_default_knowledge_base_dialogue": "Default documents",
"knowledge_qa_type_add_knowledge_base_dialogue": "Added documents",
"knowledge_qa_type_url_knowledge_dialogue": "Chat with url",
"dialogue_use_plugin": "Dialogue Extension",
"create_knowledge_base": "Create Knowledge Base",
"sql_schema_info": "the schema information of database {}: {}\n",
"current_dialogue_mode": "Current dialogue mode",
"database_smart_assistant": "Database smart assistant",
"sql_vs_setting": "In the automatic execution mode, DB-GPT can have the ability to execute SQL, read data from the network, automatically store and learn",
"chat_use_plugin": "Plugin Mode",
"select_plugin": "Select Plugin",
"knowledge_qa": "Documents QA",
"configure_knowledge_base": "Configure Documents",
"url_input_label": "Please input url",
"new_klg_name": "New document name",
"add_as_new_klg": "Add as new documents",
"add_file_to_klg": "Add file to documents",

View File

@@ -0,0 +1,33 @@
from typing import List
import json
import os
import datetime
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pathlib import Path
from pilot.configs.config import Config
from pilot.scene.message import (
OnceConversation,
conversation_from_dict,
conversations_to_dict,
)
CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
histories_map = {}
def __init__(self, chat_session_id: str):
self.chat_session_id = chat_session_id
self.histories_map.update({chat_session_id: []})
def messages(self) -> List[OnceConversation]:
return self.histories_map.get(self.chat_session_id)
def append(self, once_message: OnceConversation) -> None:
self.histories_map.get(self.chat_session_id).append(once_message)
def clear(self) -> None:
self.histories_map.pop(self.chat_session_id)
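A quick sketch of the in-memory store above. Note that the histories map is a class attribute, so every instance shares one underlying store; OnceConversation construction is simplified here and may need real fields from pilot.scene.message:

from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.scene.message import OnceConversation

memory = MemHistoryMemory("session-42")
memory.append(OnceConversation())  # store one conversation round
print(len(memory.messages()))      # 1
memory.clear()                     # drops this session's history entirely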

View File

@@ -21,22 +21,46 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
}
messages = prompt.split(stop)
# Add history conversation
for i in range(1, len(messages) - 2, 2):
history.append(
{"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
)
history.append(
{
"role": "system",
"content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
}
)
for message in messages:
if len(message) <= 0:
continue
if "human:" in message:
history.append(
{"role": "user", "content": message.split("human:")[1]},
)
elif "system:" in message:
history.append(
{
"role": "system",
"content": message.split("system:")[1],
}
)
elif "ai:" in message:
history.append(
{
"role": "ai",
"content": message.split("ai:")[1],
}
)
else:
history.append(
{
"role": "system",
"content": message,
}
)
# Move the last user message to the end
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
if last_user_input:
history.remove(last_user_input)
history.append(last_user_input)
# Add user query
query = messages[-2].split(ROLE_USER + ":")[1]
history.append({"role": "user", "content": query})
payloads = {
"model": "gpt-3.5-turbo", # just for test, remove this later
"messages": history,

View File

@@ -13,12 +13,17 @@ from typing import (
TypeVar,
Union,
)
from pilot.utils import build_logger
import re
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.prompts.base import PromptValue
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
T = TypeVar("T")
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
CFG = Config()
class BaseOutputParser(ABC):
@@ -31,9 +36,39 @@ class BaseOutputParser(ABC):
self.sep = sep
self.is_stream_out = is_stream_out
def __post_process_code(self, code):
sep = "\n```"
if sep in code:
blocks = code.split(sep)
if len(blocks) % 2 == 1:
for i in range(1, len(blocks), 2):
blocks[i] = blocks[i].replace("\\_", "_")
code = sep.join(blocks)
return code
# TODO: bind this to the model later
def _parse_model_stream_resp(self, response, sep: str):
pass
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL:
output = data["text"].strip()
else:
output = data["text"].strip()
output = self.__post_process_code(output)
yield output
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
yield output
def _parse_model_nostream_resp(self, response, sep: str):
text = response.text.strip()
@@ -57,12 +92,12 @@ class BaseOutputParser(ABC):
ai_response = ai_response.replace("\n", "")
ai_response = ai_response.replace("\_", "_")
ai_response = ai_response.replace("\*", "*")
print("un_stream clear response:{}", ai_response)
print("un_stream ai response:", ai_response)
return ai_response
else:
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def parse_model_server_out(self, response) -> str:
def parse_model_server_out(self, response):
"""
parse the model server http response
Args:
@@ -85,7 +120,28 @@ class BaseOutputParser(ABC):
Returns:
"""
pass
cleaned_output = model_out_text.rstrip()
if "```json" in cleaned_output:
_, cleaned_output = cleaned_output.split("```json")
if "```" in cleaned_output:
cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```"):]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing")
json_pattern = r"{(.+?)}"
m = re.search(json_pattern, cleaned_output)
if m:
cleaned_output = m.group(0)
else:
raise ValueError("model server out not fllow the prompt!")
cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '')
return cleaned_output
def parse_view_response(self, ai_text) -> str:
"""
@@ -96,7 +152,7 @@ class BaseOutputParser(ABC):
Returns:
"""
pass
return ai_text
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""

View File

@@ -1,143 +0,0 @@
import platform
from typing import List, Optional
import distro
import yaml
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
from pilot.prompts.prompt import (
DEFAULT_PROMPT_OTHER,
DEFAULT_TRIGGERING_PROMPT,
build_default_prompt_generator,
)
class AutoModePrompt:
""" """
def __init__(
self,
ai_goals: list | None = None,
) -> None:
"""
Initialize a class instance
Parameters:
ai_goals (list): The list of objectives the AI is supposed to complete.
Returns:
None
"""
if ai_goals is None:
ai_goals = []
self.ai_goals = ai_goals
self.prompt_generator = None
self.command_registry = None
def construct_follow_up_prompt(
self,
user_input: List[str],
last_auto_return: str = None,
prompt_generator: Optional[PromptGenerator] = None,
) -> str:
"""
Build complete prompt information based on subsequent dialogue information entered by the user
Args:
self:
prompt_generator:
Returns:
"""
prompt_start = DEFAULT_PROMPT_OTHER
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
prompt_generator.goals = user_input
prompt_generator.command_registry = self.command_registry
# Load available commands from plugins
cfg = Config()
for plugin in cfg.plugins:
if not plugin.can_handle_post_prompt():
continue
prompt_generator = plugin.post_prompt(prompt_generator)
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
if not self.ai_goals:
self.ai_goals = user_input
for i, goal in enumerate(self.ai_goals):
full_prompt += (
f"{i+1}.According to the provided Schema information, {goal}\n"
)
# if last_auto_return == None:
# full_prompt += f"{cfg.last_plugin_return}\n\n"
# else:
# full_prompt += f"{last_auto_return}\n\n"
full_prompt += f"Constraints:\n\n{DEFAULT_TRIGGERING_PROMPT}\n"
full_prompt += """Based on the above definition, answer the current goal and ensure that the response meets both the current constraints and the above definition and constraints"""
self.prompt_generator = prompt_generator
return full_prompt
def construct_first_prompt(
self,
first_message: List[str] = [],
db_schemes: str = None,
prompt_generator: Optional[PromptGenerator] = None,
) -> str:
"""
Build complete prompt information based on the initial dialogue information entered by the user
Args:
self:
prompt_generator:
Returns:
"""
prompt_start = (
"Your decisions must always be made independently without"
" seeking user assistance. Play to your strengths as an LLM and pursue"
" simple strategies with no legal complications."
""
)
if prompt_generator is None:
prompt_generator = build_default_prompt_generator()
prompt_generator.goals = first_message
prompt_generator.command_registry = self.command_registry
# Load available commands from plugins
cfg = Config()
for plugin in cfg.plugins:
if not plugin.can_handle_post_prompt():
continue
prompt_generator = plugin.post_prompt(prompt_generator)
if cfg.execute_local_commands:
# add OS info to prompt
os_name = platform.system()
os_info = (
platform.platform(terse=True)
if os_name != "Linux"
else distro.name(pretty=True)
)
prompt_start += f"\nThe OS you are running on is: {os_info}"
# Construct full prompt
full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
if not self.ai_goals:
self.ai_goals = first_message
for i, goal in enumerate(self.ai_goals):
full_prompt += (
f"{i+1}.According to the provided Schema information,{goal}\n"
)
if db_schemes:
full_prompt += f"\nSchema:\n\n"
full_prompt += f"{db_schemes}"
# if self.api_budget > 0.0:
# full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
self.prompt_generator = prompt_generator
full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
return full_prompt

View File

@@ -3,7 +3,7 @@ import json
from typing import Any, Callable, Dict, List, Optional
class PromptGenerator:
class PluginPromptGenerator:
"""
A class for generating custom prompt strings based on constraints, commands,
resources, and performance evaluations.
@@ -133,6 +133,11 @@ class PromptGenerator:
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def generate_commands_string(self)->str:
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
def generate_prompt_string(self) -> str:
"""
Generate a prompt string based on the constraints, commands, resources,

View File

@@ -1,73 +0,0 @@
from pilot.configs.config import Config
from pilot.prompts.generator import PromptGenerator
CFG = Config()
DEFAULT_TRIGGERING_PROMPT = (
"Determine which next command to use, and respond using the format specified above"
)
DEFAULT_PROMPT_OHTER = "Previous response was excellent. Please response according to the requirements based on the new goal"
def build_default_prompt_generator() -> PromptGenerator:
"""
This function generates a prompt string that includes various constraints,
commands, resources, and performance evaluations.
Returns:
str: The generated prompt string.
"""
# Initialize the PromptGenerator object
prompt_generator = PromptGenerator()
# Add constraints to the PromptGenerator object
# prompt_generator.add_constraint(
# "~4000 word limit for short term memory. Your short term memory is short, so"
# " immediately save important information to files."
# )
prompt_generator.add_constraint(
"If you are unsure how you previously did something or want to recall past"
" events, thinking about similar events will help you remember."
)
# prompt_generator.add_constraint("No user assistance")
prompt_generator.add_constraint("Only output one correct JSON response at a time")
prompt_generator.add_constraint(
'Exclusively use the commands listed in double quotes e.g. "command name"'
)
prompt_generator.add_constraint(
"If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition"
)
prompt_generator.add_constraint(
"The generated command args need to comply with the definition of the command"
)
# Add resources to the PromptGenerator object
# prompt_generator.add_resource(
# "Internet access for searches and information gathering."
# )
# prompt_generator.add_resource("Long Term memory management.")
# prompt_generator.add_resource(
# "DB-GPT powered Agents for delegation of simple tasks."
# )
# prompt_generator.add_resource("File output.")
# Add performance evaluations to the PromptGenerator object
prompt_generator.add_performance_evaluation(
"Continuously review and analyze your actions to ensure you are performing to"
" the best of your abilities."
)
prompt_generator.add_performance_evaluation(
"Constructively self-criticize your big-picture behavior constantly."
)
prompt_generator.add_performance_evaluation(
"Reflect on past decisions and strategies to refine your approach."
)
# prompt_generator.add_performance_evaluation(
# "Every command has a cost, so be smart and efficient. Aim to complete tasks in"
# " the least number of steps."
# )
# prompt_generator.add_performance_evaluation("Write all code to a file.")
return prompt_generator

View File

@@ -1,52 +0,0 @@
from typing import Any, Callable, Dict, List, Optional
class PromptGenerator:
"""
generating custom prompt strings based on constraints
Compatible with AutoGpt Plugin;
"""
def __init__(self) -> None:
"""
Initialize the PromptGenerator object with empty lists of constraints,
commands, resources, and performance evaluations.
"""
self.constraints = []
self.commands = []
self.resources = []
self.performance_evaluation = []
self.goals = []
self.command_registry = None
self.name = "Bob"
self.role = "AI"
self.response_format = None
def add_command(
self,
command_label: str,
command_name: str,
args=None,
function: Optional[Callable] = None,
) -> None:
"""
Add a command to the commands list with a label, name, and optional arguments.
DB-GPT and Auto-GPT plugin registration command.
Args:
command_label (str): The label of the command.
command_name (str): The name of the command.
args (dict, optional): A dictionary containing argument names and their
values. Defaults to None.
function (callable, optional): A callable function to be called when
the command is executed. Defaults to None.
"""
if args is None:
args = {}
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
command = {
"label": command_label,
"name": command_name,
"args": command_args,
"function": function,
}
self.commands.append(command)

View File

@@ -31,15 +31,15 @@ DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
class PromptTemplate(BaseModel, ABC):
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template_scene: str
template_scene: Optional[str]
template_define: str
template_define: Optional[str]
"""this template define"""
template: str
template: Optional[str]
"""The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format: str
response_format: Optional[str]
"""default use stream out"""
stream_out: bool = True
""""""
@@ -57,52 +57,12 @@ class PromptTemplate(BaseModel, ABC):
"""Return the prompt type key."""
return "prompt"
def _generate_command_string(self, command: Dict[str, Any]) -> str:
"""
Generate a formatted string representation of a command.
Args:
command (dict): A dictionary containing command information.
Returns:
str: The formatted command string.
"""
args_string = ", ".join(
f'"{key}": "{value}"' for key, value in command["args"].items()
)
return f'{command["label"]}: "{command["name"]}", args: {args_string}'
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
"""
Generate a numbered list from given items based on the item_type.
Args:
items (list): A list of items to be numbered.
item_type (str, optional): The type of items in the list.
Defaults to 'list'.
Returns:
str: The formatted numbered list.
"""
if item_type == "command":
command_strings = []
if self.command_registry:
command_strings += [
str(item)
for item in self.command_registry.commands.values()
if item.enabled
]
# terminate command is added manually
command_strings += [self._generate_command_string(item) for item in items]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
if self.template:
if self.response_format:
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
def add_goals(self, goal: str) -> None:
self.goals.append(goal)
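A sketch of format() with the now-optional fields (the field values are placeholders; other fields the model may require are omitted for brevity):

import json
from pilot.prompts.prompt_new import PromptTemplate

template = PromptTemplate(
    input_variables=["input", "response"],
    template="Question: {input}\nRespond in this JSON format:\n{response}",
    response_format=json.dumps({"sql": "SQL Query to run"}, indent=4),
)
print(template.format(input="count the users"))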

View File

@@ -2,8 +2,10 @@ from enum import Enum
class ChatScene(Enum):
ChatWithDb = "chat_with_db"
ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa"
ChatExecution = "chat_execution"
ChatKnowledge = "chat_default_knowledge"
ChatNewKnowledge = "chat_new_knowledge"
ChatUrlKnowledge = "chat_url_knowledge"
ChatNormal = "chat_normal"

View File

@@ -1,4 +1,8 @@
import time
from abc import ABC, abstractmethod
import datetime
import traceback
import json
from pydantic import BaseModel, Field, root_validator, validator, Extra
from typing import (
Any,
@@ -19,14 +23,20 @@ from pilot.scene.message import OnceConversation
from pilot.prompts.prompt_new import PromptTemplate
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.common.schema import SeparatorStyle
from pilot.scene.base import ChatScene
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
HumanMessage,
AIMessage,
ViewMessage,
)
from pilot.configs.config import Config
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
@@ -47,20 +57,23 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(self, chat_mode, chat_session_id, current_user_input):
def __init__(self, temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
### TODO
### the storage method can be made configurable
# self.memory = MemHistoryMemory(chat_session_id)
## TEST
self.memory = FileHistoryMemory(chat_session_id)
### load prompt template
self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value
]
self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
self.history_message: List[OnceConversation] = []
self.current_message: OnceConversation = OnceConversation()
self.current_tokens_used: int = 0
self.temperature = temperature
self.max_new_tokens = max_new_tokens
### load this chat_session_id's chat history
self._load_history(self.chat_session_id)
@@ -74,14 +87,183 @@ class BaseChat(ABC):
def chat_type(self) -> str:
raise NotImplementedError("Not supported for this chat type.")
@abstractmethod
def generate_input_values(self):
pass
@abstractmethod
def do_with_prompt_response(self, prompt_response):
pass
def __call_base(self):
input_values = self.generate_input_values()
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now()
# TODO
self.current_message.tokens = 0
current_prompt = None
if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values)
### Build the current conversation; decide whether this is the first-round prompt construction and whether a database switch needs handling
if self.history_message:
## TODO: for scenarios with conversation history, decide how to handle a database switch
logger.info(
f"There are already {len(self.history_message)} rounds of conversations!"
)
if current_prompt:
self.current_message.add_system_message(current_prompt)
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
"temperature": float(self.temperature),
"max_new_tokens": int(self.max_new_tokens),
"stop": self.prompt_template.sep,
}
return payload
def stream_call(self):
payload = self.__call_base()
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
show_info = ""
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
timeout=120,
)
ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
for resp_text_chunk in ai_response_text:
show_info = resp_text_chunk
yield resp_text_chunk
self.current_message.add_ai_message(show_info)
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### Persist the conversation record
self.memory.append(self.current_message)
def nostream_call(self):
payload = self.__call_base()
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
### Call the non-streaming model-server endpoint
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_server_out(response)
)
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
result = self.do_with_prompt_response(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"):
if isinstance(prompt_define_response.thoughts, dict):
if "speak" in prompt_define_response.thoughts:
speak_to_user = prompt_define_response.thoughts.get("speak")
else:
speak_to_user = str(prompt_define_response.thoughts)
else:
if hasattr(prompt_define_response.thoughts, "speak"):
speak_to_user = prompt_define_response.thoughts.get("speak")
elif hasattr(prompt_define_response.thoughts, "reasoning"):
speak_to_user = prompt_define_response.thoughts.get("reasoning")
else:
speak_to_user = prompt_define_response.thoughts
else:
speak_to_user = prompt_define_response
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
self.current_message.add_view_message(view_message)
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### Persist the conversation record
self.memory.append(self.current_message)
return self.current_ai_response()
def call(self):
pass
if self.prompt_template.stream_out:
yield self.stream_call()
else:
return self.nostream_call()
def chat_show(self):
pass
def generate_llm_text(self) -> str:
text = ""
if self.prompt_template.template_define:
text = self.prompt_template.template_define + self.prompt_template.sep
### Process history messages
if len(self.history_message) > self.chat_retention_rounds:
### Merge the first round and the last n rounds of history into the conversation record; view messages shown to the user are filtered out of the context prompt
for first_message in self.history_message[0].messages:
if not isinstance(first_message, ViewMessage):
text += (
first_message.type
+ ":"
+ first_message.content
+ self.prompt_template.sep
)
index = self.chat_retention_rounds - 1
for conversation in self.history_message[-index:]:
for last_message in conversation.messages:
if not isinstance(last_message, ViewMessage):
text += (
last_message.type
+ ":"
+ last_message.content
+ self.prompt_template.sep
)
else:
### Concatenate the full history directly
for conversation in self.history_message:
for message in conversation.messages:
if not isinstance(message, ViewMessage):
text += (
message.type
+ ":"
+ message.content
+ self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
now_message.type + ":" + now_message.content + self.prompt_template.sep
)
return text
# Temporary, for front-end compatibility
def current_ai_response(self) -> str:
pass
for message in self.current_message.messages:
if message.type == "view":
return message.content
return None
def _load_history(self, session_id: str) -> List[OnceConversation]:
"""
@@ -102,3 +284,4 @@ class BaseChat(ABC):
"""
pass

View File

@@ -0,0 +1,57 @@
import json
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
CFG = Config()
class ChatWithDbAutoExecute(BaseChat):
chat_scene: str = ChatScene.ChatWithDbExecute.value
"""Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, db_name, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbExecute,
chat_session_id=chat_session_id,
current_user_input=user_input)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should chose db!")
self.db_name = db_name
self.database = CFG.local_db
# Prepare DB info (obtain a session for the target database)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect)
}
return input_values
def do_with_prompt_response(self, prompt_response):
return self.database.run(self.db_connect, prompt_response.sql)
if __name__ == "__main__":
ss = "{\n \"thoughts\": \"to get the user's city, we need to join the users table with the tran_order table using the user_name column. we also need to filter the results to only show orders for user test1.\",\n \"sql\": \"select o.order_id, o.product_name, u.city from tran_order o join users u on o.user_name = u.user_name where o.user_name = 'test1' limit 5\"\n}"
ss = ss.strip().replace('\n', '').replace('\\n', '').replace('\\', '')
print(ss)
json.loads(ss)
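A rough driver for the class above (the parameter values are placeholders; it assumes CFG.local_db and the model server are configured):

chat = ChatWithDbAutoExecute(
    temperature=0.7,
    max_new_tokens=512,
    chat_session_id="session-42",
    db_name="dbgpt",
    user_input="How many users signed up today?",
)
# nostream_call() formats the prompt, parses the SQL out of the reply, and runs it
print(chat.nostream_call())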

View File

@@ -20,32 +20,11 @@ class DbChatOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out)
def parse_model_server_out(self, response) -> str:
return super().parse_model_server_out(response)
def parse_prompt_response(self, model_out_text):
cleaned_output = model_out_text.rstrip()
if "```json" in cleaned_output:
_, cleaned_output = cleaned_output.split("```json")
if "```" in cleaned_output:
cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing")
json_pattern = r"{(.+?)}"
m = re.search(json_pattern, cleaned_output)
if m:
cleaned_output = m.group(0)
else:
raise ValueError("model server out not fllow the prompt!")
response = json.loads(cleaned_output)
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
sql, thoughts = response["sql"], response["thoughts"]
return SqlAction(sql, thoughts)
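What the parser yields on a well-formed reply (assuming SqlAction exposes sql and thoughts attributes; the input string is made up):

parser = DbChatOutputParser(sep="###", is_stream_out=False)
action = parser.parse_prompt_response(
    '{"thoughts": "count the rows", "sql": "SELECT COUNT(*) FROM users"}'
)
print(action.sql)       # SELECT COUNT(*) FROM users
print(action.thoughts)  # count the rows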

View File

@@ -1,36 +1,28 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.out_parser import DbChatOutputParser, SqlAction
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
PROMPT_SUFFIX = """Only use the following tables:
{table_info}
Question: {input}
"""
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
Use as few tables as possible when querying.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
PROMPT_SUFFIX = """Only use the following tables generate sql:
{table_info}
Question: {input}
"""
@@ -48,14 +40,19 @@ RESPONSE_FORMAT = {
"sql": "SQL Query to run",
}
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
chat_db_prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDb.value,
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value,
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
@@ -63,5 +60,5 @@ chat_db_prompt = PromptTemplate(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})

View File

@@ -1,279 +0,0 @@
import requests
import datetime
import threading
import json
import traceback
from urllib.parse import urljoin
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from typing import Any, Iterable, List, Optional
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
HumanMessage,
AIMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.scene.chat_db.out_parser import SqlAction
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_db.prompt import chat_db_prompt
from pilot.out_parser.base import BaseOutputParser
from pilot.scene.chat_db.out_parser import DbChatOutputParser
CFG = Config()
class ChatWithDb(BaseChat):
chat_scene: str = ChatScene.ChatWithDb.value
"""Number of results to return from the query"""
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
self.db_name = db_name
self.database = CFG.local_db
# Prepare DB info (obtain a session for the target database)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
def call(self) -> str:
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
# "stop": self.sep_style,
}
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now()
# TODO
self.current_message.tokens = 0
current_prompt = self.prompt_template.format(**input_values)
### Build the current conversation; decide whether this is the first-round prompt construction and whether a database switch needs handling
if self.history_message:
## TODO: for scenarios with conversation history, decide how to handle a database switch
logger.info(
f"There are already {len(self.history_message)} rounds of conversations!"
)
self.current_message.add_system_message(current_prompt)
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
"temperature": float(self.temperature),
"max_new_tokens": int(self.max_new_tokens),
"stop": self.prompt_template.sep,
}
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
### Call the non-streaming model service endpoint
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
ai_response_text = (
self.prompt_template.output_parser.parse_model_server_out(response)
)
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.database.run(self.db_connect, prompt_define_response.sql)
if hasattr(prompt_define_response, "thoughts"):
if prompt_define_response.thoughts.get("speak"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("speak"), result
)
)
elif prompt_define_response.thoughts.get("reasoning"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("reasoning"), result
)
)
else:
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts, result
)
)
else:
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response, result
)
)
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### Persist the conversation record
self.memory.append(self.current_message)
def chat_show(self):
ret = []
# A single round of conversation has exactly one user record and one AI record
# TODO: render the reasoning process in the frontend
for message in self.current_message.messages:
if isinstance(message, HumanMessage):
ret.append([message.content, None])
# Whether to show the reasoning process
if isinstance(message, ViewMessage):
ret[-1][-1] = message.content
return ret
# Temporary, for frontend compatibility
def current_ai_response(self) -> str:
for message in self.current_message.messages:
if message.type == "view":
return message.content
return None
def generate_llm_text(self) -> str:
text = self.prompt_template.template_define + self.prompt_template.sep
### Process the history messages first
if len(self.history_message) > self.chat_retention_rounds:
### Merge the first round and the latest rounds of history into the conversation record; messages shown to the user must be filtered out when building the context
for first_message in self.history_message[0].messages:
if not isinstance(first_message, ViewMessage):
text += (
first_message.type
+ ":"
+ first_message.content
+ self.prompt_template.sep
)
index = self.chat_retention_rounds - 1
# history_message[-index:] is a list of conversations; iterate each conversation's messages
for conversation in self.history_message[-index:]:
    for last_message in conversation.messages:
        if not isinstance(last_message, ViewMessage):
            text += (
                last_message.type
                + ":"
                + last_message.content
                + self.prompt_template.sep
            )
else:
### Concatenate the full history directly
for conversation in self.history_message:
for message in conversation.messages:
if not isinstance(message, ViewMessage):
text += (
message.type
+ ":"
+ message.content
+ self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
now_message.type + ":" + now_message.content + self.prompt_template.sep
)
return text
@property
def chat_type(self) -> str:
return ChatScene.ChatExecution.value
if __name__ == "__main__":
# chat: ChatWithDb = ChatWithDb("chat123", "gpt-user", "查询用户信息")
#
# chat.call()
#
# resp = chat.chat_show()
#
# print(vars(resp))
# memory = FileHistoryMemory("test123")
# once1 = OnceConversation()
# once1.add_user_message("问题测试")
# once1.add_system_message("prompt1")
# once1.add_system_message("prompt2")
# once1.chat_order = 1
# once1.set_start_time(datetime.datetime.now())
# memory.append(once1)
#
# once = OnceConversation()
# once.add_user_message("问题测试2")
# once.add_system_message("prompt3")
# once.add_system_message("prompt4")
# once.chat_order = 2
# once.set_start_time(datetime.datetime.now())
# memory.append(once)
db: Database = CFG.local_db
db_connect = db.get_session("gpt-user")
data = db.run(db_connect, "select * from users")
print(generate_htm_table(data))
#
# print(db.run(db_connect, "select * from users"))
#
# #
# # def print_numbers():
# # db_connect1 = db.get_session("dbgpt-test")
# # cursor1 = db_connect1.execute(text("select * from test_name"))
# # if cursor1.returns_rows:
# # result1 = cursor1.fetchall()
# # print( result1)
# #
# #
# # # create the thread
# # t = threading.Thread(target=print_numbers)
# # # start the thread
# # t.start()
#
# print(db.run(db_connect, "select * from tran_order"))
#
# print(db.run(db_connect, "select count(*) as aa from tran_order"))
#
# print(db.table_simple_info(db_connect))
# my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# index = 3
# last_three_elements = my_list[-index:]
# print(last_three_elements)
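The retention logic in generate_llm_text keeps the first round plus the last chat_retention_rounds - 1 rounds. A standalone sketch of that windowing, with plain strings standing in for OnceConversation objects:

# Hypothetical stand-ins for OnceConversation objects.
chat_retention_rounds = 3
history = [f"round-{i}" for i in range(10)]

# First round plus the most recent (chat_retention_rounds - 1) rounds.
index = chat_retention_rounds - 1
kept = [history[0]] + history[-index:]
print(kept)  # ['round-0', 'round-8', 'round-9']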

View File

@@ -0,0 +1,56 @@
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.professional_qa.prompt import prompt
CFG = Config()
class ChatWithDbQA(BaseChat):
chat_scene: str = ChatScene.ChatWithDbQA.value
"""Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, db_name, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.db_name = db_name
if db_name:
self.database = CFG.local_db
# Prepare DB info (get a connection to the selected database)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
def generate_input_values(self):
table_info = ""
dialect = "mysql"
if self.db_name:
table_info = self.database.table_simple_info(self.db_connect)
dialect = self.database.dialect
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": dialect,
"table_info": table_info
}
return input_values
def do_with_prompt_response(self, prompt_response):
# auto_execute is never set in __init__; default to False so QA mode returns the parsed response
if getattr(self, "auto_execute", False):
return self.database.run(self.db_connect, prompt_response.sql)
else:
return prompt_response
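A hypothetical usage sketch of ChatWithDbQA, wiring the constructor arguments shown above (the session id, database name, and question are made up):

chat = ChatWithDbQA(
    temperature=0.5,
    max_new_tokens=512,
    chat_session_id="session-001",
    db_name="gpt-user",
    user_input="Which tables store order data?",
)
print(chat.generate_input_values())  # input, top_k, dialect, table_info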

View File

@@ -0,0 +1,19 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def get_format_instructions(self) -> str:
pass

View File

@@ -0,0 +1,48 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.professional_qa.out_parser import NormalChatOutputParser
from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """
PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
{table_info}
Question: {input}
"""
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table; only ask for the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbQA.value,
input_variables=["input", "table_info", "dialect", "top_k"],
response_format=None,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
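To see how the input variables land in the final prompt, a sketch using plain str.format over the same template pieces (the table info is a made-up example; the real PromptTemplate performs the equivalent substitution):

filled = (_DEFAULT_TEMPLATE + PROMPT_SUFFIX).format(
    dialect="mysql",
    top_k=5,
    table_info="users(id, name, created_at); tran_order(id, user_id, amount)",
    input="How many users signed up this week?",
)
print(filled)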

View File

@@ -1,24 +1,72 @@
import requests
import datetime
from urllib.parse import urljoin
from typing import List
import traceback
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.message import OnceConversation
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.commands.command import execute_command
from pilot.prompts.generator import PluginPromptGenerator
from pilot.scene.chat_execution.prompt import prompt
CFG = Config()
class ChatWithPlugin(BaseChat):
chat_scene: str = ChatScene.ChatExecution.value
plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None
def __init__(self, chat_mode, chat_session_id, current_user_input):
super().__init__(chat_mode, chat_session_id, current_user_input)
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, plugin_selector: str = None):
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatExecution,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
# Load the commands available from the plugins
self.select_plugin = plugin_selector
if self.select_plugin:
for plugin in CFG.plugins:
if plugin._name == plugin_selector:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
def call(self):
super().call()
else:
for plugin in CFG.plugins:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)),
"commands_infos": self.plugins_prompt_generator.generate_commands_string()
}
return input_values
def do_with_prompt_response(self, prompt_response):
## plugin command run
return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args',{}), self.plugins_prompt_generator)
def chat_show(self):
super().chat_show()
def _load_history(self, session_id: str) -> List[OnceConversation]:
return super()._load_history(session_id)
def __list_to_prompt_str(self, items: List) -> str:
if items:
separator = '\n'
return separator.join(items)
else:
return ""
def generate(self, p) -> str:
return super().generate(p)
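A hypothetical usage sketch of ChatWithPlugin; the plugin name and inputs are made up, and it assumes CFG.plugins has already been populated by scan_plugins:

chat = ChatWithPlugin(
    temperature=0.7,
    max_new_tokens=512,
    chat_session_id="session-001",
    user_input="Generate an image of a red fox",
    plugin_selector="image_gen",
)
print(chat.generate_input_values()["commands_infos"])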

View File

@@ -0,0 +1,33 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class PluginAction(NamedTuple):
command: Dict
thoughts: Dict
class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
response = json.loads(super().parse_prompt_response(model_out_text))
command, thoughts = response["command"], response["thoughts"]
return PluginAction(command, thoughts)
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
print(f"parse_view_response:{speak},{str(data)}" )
view_text = f"##### {speak}" + "\n" + str(data)
return view_text
def get_format_instructions(self) -> str:
pass
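An illustrative reply in the RESPONSE_FORMAT defined in the prompt module below, parsed the same way parse_prompt_response does (minus the base-class cleanup); the command and arguments are made up:

import json

sample = '''
{
    "thoughts": {"speak": "Generating the requested image"},
    "command": {"name": "image_gen", "args": {"prompt": "a red fox"}}
}
'''

response = json.loads(sample)
command, thoughts = response["command"], response["thoughts"]
print(command["name"], command.get("args", {}))
print(thoughts.get("speak"))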

View File

@@ -0,0 +1,66 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications."""
PROMPT_SUFFIX = """
Goals:
{input}
"""
_DEFAULT_TEMPLATE = """
Constraints:
Exclusively use the commands listed in double quotes e.g. "command name"
Reflect on past decisions and strategies to refine your approach.
Constructively self-criticize your big-picture behavior constantly.
{constraints}
Commands:
{commands_infos}
"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""
RESPONSE_FORMAT = {
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
},
"command": {"name": "command name", "args": {"arg name": "value"}},
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
PROMPT_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ChatExecution.value,
input_variables=["input", "constraints", "commands_infos", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -1,8 +1,14 @@
from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton
from pilot.scene.chat_db.chat import ChatWithDb
import inspect
import importlib
from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
class ChatFactory(metaclass=Singleton):
@staticmethod
@@ -13,5 +19,5 @@ class ChatFactory(metaclass=Singleton):
if cls.chat_scene == chat_mode:
implementation = cls(**kwargs)
if implementation == None:
raise Exception("Invalid implementation name:" + chat_mode)
raise Exception(f"Invalid implementation name:{chat_mode}")
return implementation
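A sketch of resolving a chat implementation by scene value, mirroring how http_bot calls the factory later in this diff (parameter values are made up):

chat_param = {
    "temperature": 0.5,
    "max_new_tokens": 512,
    "chat_session_id": "session-001",
    "user_input": "hello",
}
chat = ChatFactory().get_implementation(ChatScene.ChatNormal.value, **chat_param)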

View File

@@ -0,0 +1,69 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatNewKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatNewKnowledge.value
"""Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNewKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.knowledge_name = knowledge_name
vector_store_config = {
"vector_store_name": knowledge_name,
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatNewKnowledge.value

View File

@@ -0,0 +1,19 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def get_format_instructions(self) -> str:
pass

View File

@@ -0,0 +1,43 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config()
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatNewKnowledge.value,
input_variables=["context", "question"],
response_format=None,
template_define=None,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -0,0 +1,66 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatDefaultKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatKnowledge.value
"""Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
vector_store_config = {
"vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value

View File

@@ -0,0 +1,19 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def get_format_instructions(self) -> str:
pass

View File

@@ -0,0 +1,43 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config()
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatKnowledge.value,
input_variables=["context", "question"],
response_format=None,
template_define=None,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -0,0 +1,71 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatUrlKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatUrlKnowledge.value
"""Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatUrlKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.url = url
vector_store_config = {
"vector_store_name": url,
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path=url,
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
# Embed the URL source into the vector store
self.knowledge_embedding_client.knowledge_embedding()
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatUrlKnowledge.value

View File

@@ -0,0 +1,19 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def get_format_instructions(self) -> str:
pass

View File

@@ -0,0 +1,43 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config()
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatUrlKnowledge.value,
input_variables=["context", "question"],
response_format=None,
template_define=None,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -0,0 +1,43 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_normal.prompt import prompt
CFG = Config()
class ChatNormal(BaseChat):
chat_scene: str = ChatScene.ChatNormal.value
"""Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNormal,
chat_session_id=chat_session_id,
current_user_input=user_input)
def generate_input_values(self):
input_values = {
"input": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatNormal.value

View File

@@ -0,0 +1,22 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)
def get_format_instructions(self) -> str:
pass

View File

@@ -1,31 +1,33 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
def stream_write_and_read(lst):
# Use yield from to flatten the iterable lst
yield from lst
while True:
val = yield
lst.append(val)
CFG = Config()
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatNormal.value,
input_variables=["input"],
response_format=None,
template_define=None,
template=None,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
)
if __name__ == "__main__":
# create an empty list
my_list = []
CFG.prompt_templates.update({prompt.template_scene: prompt})
# write data through the generator
stream_writer = stream_write_and_read(my_list)
next(stream_writer)
stream_writer.send(10)
print(1)
stream_writer.send(20)
print(2)
stream_writer.send(30)
print(3)
# read the data back through the generator
stream_reader = stream_write_and_read(my_list)
next(stream_reader)
print(stream_reader.send(None))
print(stream_reader.send(None))
print(stream_reader.send(None))

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import traceback
import argparse
import datetime
import json
@@ -9,7 +9,6 @@ import shutil
import sys
import time
import uuid
from urllib.parse import urljoin
import gradio as gr
import requests
@@ -30,18 +29,19 @@ from pilot.configs.model_config import (
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.connections.mysql import MySQLOperator
from pilot.conversation import (
SeparatorStyle,
conv_qa_prompt_template,
conv_templates,
conversation_sql_mode,
conversation_types,
chat_mode_title,
default_conversation,
)
from pilot.plugins import scan_plugins
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.common.plugins import scan_plugins
from pilot.prompts.generator import PluginPromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
@@ -95,6 +95,11 @@ default_knowledge_base_dialogue = get_lang_text(
add_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_add_knowledge_base_dialogue"
)
url_knowledge_dialogue = get_lang_text(
"knowledge_qa_type_url_knowledge_dialogue"
)
knowledge_qa_type_list = [
llm_native_dialogue,
default_knowledge_base_dialogue,
@@ -111,19 +116,19 @@ def get_simlar(q):
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(**DB_SETTINGS)
message = ""
schemas = mo.get_schema(dbname)
db_connect = CFG.local_db.get_session(dbname)
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s["schema_info"] + ";"
message += s+ ";"
return get_lang_text("sql_schema_info").format(dbname, message)
def get_database_list():
mo = MySQLOperator(**DB_SETTINGS)
return mo.get_db_list()
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
get_window_url_params = """
@@ -210,285 +215,127 @@ def post_process_code(code):
return code
def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
if mode == conversation_types["default_knownledge"] and not db_selector:
return ChatScene.ChatKnowledge
elif mode == conversation_types["custome"] and not db_selector:
return ChatScene.ChatNewKnowledge
elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
return ChatScene.ChatWithDb
elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
def get_chat_mode(selected, param=None) -> ChatScene:
if chat_mode_title['chat_use_plugin'] == selected:
return ChatScene.ChatExecution
elif chat_mode_title['knowledge_qa'] == selected:
mode = param
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge
elif mode == conversation_types["custome"]:
return ChatScene.ChatNewKnowledge
elif mode == conversation_types["url"]:
return ChatScene.ChatUrlKnowledge
else:
return ChatScene.ChatNormal
else:
return ChatScene.ChatNormal
sql_mode = param
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
return ChatScene.ChatWithDbExecute
else:
return ChatScene.ChatWithDbQA
def chatbot_callback(state, message):
print(f"chatbot_callback:{message}")
state.messages[-1][-1] = f"{message}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
def http_bot(
state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
):
logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
start_tstamp = time.time()
scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
model_name = CFG.LLM_MODEL
if ChatScene.ChatWithDb == scene:
logger.info("基于DB对话走新的模式")
logger.info(f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
if chat_mode_title['knowledge_qa'] == selected:
scene: ChatScene = get_chat_mode(selected, mode)
elif chat_mode_title['chat_use_plugin'] == selected:
scene: ChatScene = get_chat_mode(selected)
else:
scene: ChatScene = get_chat_mode(selected, sql_mode)
print(f"chat scene:{scene.value}")
if ChatScene.ChatWithDbExecute == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatWithDbQA == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
# TODO: this request should incorporate the existing knowledge base so the answer is grounded in it; the prompt still needs tuning
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
query = state.messages[-2][1]
# The first round of conversation needs the prompt added
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
# The first round in autogpt mode needs its own dedicated prompt
system_prompt = auto_prompt.construct_first_prompt(
fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname)
)
logger.info("[TEST]:" + system_prompt)
template_name = "auto_dbgpt_one_shot"
new_state = conv_templates[template_name].copy()
new_state.append_message(role="USER", message=system_prompt)
# new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
else:
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
# Add context hints to the prompt so answers draw on existing knowledge; should the context go only in the first round, or in every round?
# If the user's questions vary widely, the hint should be added in every round.
if db_selector:
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
new_state.conv_id = uuid.uuid4().hex
state = new_state
else:
### Subsequent conversation rounds
query = state.messages[-2][1]
# The first round of conversation needs the prompt added
if mode == conversation_types["custome"]:
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
# Add context hints to the prompt so answers draw on existing knowledge; should the context go only in the first round, or in every round?
# If the user's questions vary widely, the hint should be added in every round.
if db_selector:
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
state = new_state
elif sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
## Get the return value of the last plugin call
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
state.messages[0][0] = ""
state.messages[0][1] = ""
state.messages[-2][1] = follow_up_prompt
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["default_knownledge"] and not db_selector:
vector_store_config = {
"vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
print("vector store name: ", vector_store_name["vs_name"])
vector_store_config = {
"vector_store_name": vector_store_name["vs_name"],
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
payload = {
"model": model_name,
"prompt": prompt,
"temperature": float(temperature),
"max_new_tokens": int(max_new_tokens),
"stop": state.sep
if state.sep_style == SeparatorStyle.SINGLE
else state.sep2,
elif ChatScene.ChatExecution == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"plugin_selector": plugin_selector,
"user_input": state.last_user_input,
}
logger.info(f"Requert: \n{payload}")
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatNormal == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatNewKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"knowledge_name": knowledge_name
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatUrlKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"url": url_input
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
print(response.json())
print(str(response))
try:
text = response.text.strip()
text = text.rstrip()
respObj = json.loads(text)
xx = respObj["response"]
xx = xx.strip(b"\x00".decode())
respObj_ex = json.loads(xx)
if respObj_ex["error_code"] == 0:
ai_response = None
all_text = respObj_ex["text"]
### Parse the returned text to extract the AI reply
tmpResp = all_text.split(state.sep)
last_index = -1
for i in range(len(tmpResp)):
if tmpResp[i].find("ASSISTANT:") != -1:
last_index = i
ai_response = tmpResp[last_index]
ai_response = ai_response.replace("ASSISTANT:", "")
ai_response = ai_response.replace("\n", "")
ai_response = ai_response.replace("\_", "_")
print(ai_response)
if ai_response is None:
state.messages[-1][-1] = "ASSISTANT failed to reply correctly; the raw result was:\n" + all_text
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
plugin_resp = execute_ai_response_json(
auto_prompt.prompt_generator, ai_response
)
cfg.set_last_plugin_return(plugin_resp)
print(plugin_resp)
state.messages[-1][-1] = (
"Model推理信息:\n"
+ ai_response
+ "\n\nDB-GPT执行结果:\n"
+ plugin_resp
)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
except NotCommands as e:
print("命令执行:" + e.message)
state.messages[-1][-1] = (
"命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
# Streaming output
state.messages[-1][-1] = ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
# Stream output
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
print("****************:", data)
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
if not chat.prompt_template.stream_out:
logger.info("not stream out, wait model response!")
state.messages[-1][-1] = chat.nostream_call()
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
logger.info("stream out start!")
try:
stream_gen = chat.stream_call()
for msg in stream_gen:
state.messages[-1][-1] = msg
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
except Exception as e:
print(traceback.format_exc())
state.messages[-1][-1] = "Error:" + str(e)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# Record the run log
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
if state.messages[-1][-1].endswith("▌"):
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
block_css = (
code_highlight_css
@@ -515,15 +362,12 @@ def change_sql_mode(sql_mode):
def change_mode(mode):
if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
return gr.update(visible=False)
else:
if mode in [add_knowledge_base_dialogue]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_tab():
autogpt = True
def build_single_model_ui():
notice_markdown = get_lang_text("db_gpt_introduction")
@@ -552,7 +396,16 @@ def build_single_model_ui():
interactive=True,
label=get_lang_text("max_input_token_size"),
)
tabs = gr.Tabs()
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
return evt.value
selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
tabs.select(on_select, None, selected)
with tabs:
tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
with tab_sql:
@@ -572,11 +425,34 @@ def build_single_model_ui():
get_lang_text("sql_generate_mode_none"),
],
show_label=False,
value=get_lang_text("sql_generate_mode_none"),
value=get_lang_text("sql_generate_mode_none")
)
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
# tab_plugin.select(change_func)
with tab_plugin:
print("tab_plugin in...")
with gr.Row(elem_id="plugin_selector"):
# TODO
plugin_selector = gr.Dropdown(
label=get_lang_text("select_plugin"),
choices=list(plugins_select_info().keys()),
value="",
interactive=True,
show_label=True,
type="value"
).style(container=False)
def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
print(f"user plugin:{plugins_select_info().get(evt.value)}")
return plugins_select_info().get(evt.value)
plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
plugin_selector.select(plugin_change, None, plugin_selected)
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
with tab_qa:
mode = gr.Radio(
@@ -584,14 +460,25 @@ def build_single_model_ui():
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
url_knowledge_dialogue,
],
show_label=False,
value=llm_native_dialogue,
)
vs_setting = gr.Accordion(
get_lang_text("configure_knowledge_base"), open=False
get_lang_text("configure_knowledge_base"), open=False, visible=False
)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
def show_url_input(evt:gr.SelectData):
if evt.value == url_knowledge_dialogue:
return gr.update(visible=True)
else:
return gr.update(visible=False)
mode.select(fn=show_url_input, inputs=None, outputs=url_input)
with vs_setting:
vs_name = gr.Textbox(
label=get_lang_text("new_klg_name"), lines=1, interactive=True
@@ -639,10 +526,14 @@ def build_single_model_ui():
clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)
gr.Markdown(learn_more_markdown)
params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, temperature, max_output_tokens] + params,
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -651,7 +542,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, temperature, max_output_tokens]+ params,
[state, chatbot] + btn_list,
)
@@ -659,7 +550,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, temperature, max_output_tokens]+ params,
[state, chatbot] + btn_list,
)
vs_add.click(
@@ -766,8 +657,8 @@ if __name__ == "__main__":
# Load executable plugin commands
command_categories = [
"pilot.commands.audio_text",
"pilot.commands.image_gen",
"pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen",
]
# Exclude disabled commands
command_categories = [

View File

View File

@@ -11,6 +11,7 @@ from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.vector_store.connector import VectorStoreConnector
CFG = Config()
@@ -61,6 +62,12 @@ class KnowledgeEmbedding:
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
elif self.file_type == "url":
embedding = URLEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
return embedding
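A sketch of ingesting a URL source and querying it, mirroring the constructor arguments ChatUrlKnowledge uses above (the URL and query are made up):

vector_store_config = {
    "vector_store_name": "https://example.com/doc",
    "text_field": "content",
    "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
client = KnowledgeEmbedding(
    file_path="https://example.com/doc",
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config=vector_store_config,
)
client.knowledge_embedding()  # ingest the URL into the vector store
docs = client.similar_search("What is DB-GPT?", VECTOR_SEARCH_TOP_K)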

View File

@@ -1,7 +1,7 @@
from pilot.vector_store.chroma_store import ChromaStore
from pilot.vector_store.milvus_store import MilvusStore
# from pilot.vector_store.milvus_store import MilvusStore
connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}
connector = {"Chroma": ChromaStore, "Milvus": None}
class VectorStoreConnector:
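With MilvusStore commented out, a guarded import would keep the connector map consistent instead of hard-coding None; a sketch, assuming pymilvus may simply be absent:

try:
    from pilot.vector_store.milvus_store import MilvusStore
except ImportError:
    MilvusStore = None  # driver not installed; Milvus stays disabled

connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}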