Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-16 22:51:24 +00:00)
add plugin mode
@@ -55,54 +55,6 @@ def fix_and_parse_json(
         logger.error("参数解析错误", e)
-
-
-def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
-    """Fix the given JSON string to make it parseable and fully compliant with two techniques.
-
-    Args:
-        json_string (str): The JSON string to fix.
-
-    Returns:
-        str: The fixed JSON string.
-    """
-    assistant_reply = assistant_reply.strip()
-    if assistant_reply.startswith("```json"):
-        assistant_reply = assistant_reply[7:]
-    if assistant_reply.endswith("```"):
-        assistant_reply = assistant_reply[:-3]
-    try:
-        return json.loads(assistant_reply)  # just check the validity
-    except json.JSONDecodeError as e:  # noqa: E722
-        print(f"JSONDecodeError: {e}")
-        pass
-
-    if assistant_reply.startswith("json "):
-        assistant_reply = assistant_reply[5:]
-    assistant_reply = assistant_reply.strip()
-    try:
-        return json.loads(assistant_reply)  # just check the validity
-    except json.JSONDecodeError:  # noqa: E722
-        pass
-
-    # Parse and print Assistant response
-    assistant_reply_json = fix_and_parse_json(assistant_reply)
-    logger.debug("Assistant reply JSON: %s", str(assistant_reply_json))
-    if assistant_reply_json == {}:
-        assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets(
-            assistant_reply
-        )
-
-    logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json))
-    if assistant_reply_json != {}:
-        return assistant_reply_json
-
-    logger.error(
-        "Error: The following AI output couldn't be converted to a JSON:\n",
-        assistant_reply,
-    )
-    if CFG.speak_mode:
-        say_text("I have received an invalid JSON response from the OpenAI API.")
-
-    return {}
-
-
 def correct_json(json_to_load: str) -> str:
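For illustration, the repair chain this hunk removes bottoms out in `attempt_to_fix_json_by_finding_outermost_brackets`; the idea can be sketched in isolation (a minimal sketch, not the repository's implementation):

    import json
    import re

    def outermost_json(text: str) -> dict:
        # Best effort: extract the outermost {...} span and try to parse it.
        m = re.search(r"\{.*\}", text, re.DOTALL)
        return json.loads(m.group(0)) if m else {}

    print(outermost_json('noise {"a": 1} noise'))  # -> {'a': 1}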
@@ -4,10 +4,9 @@
 import json
 from typing import Dict
 
-from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques
 from pilot.commands.exception_not_commands import NotCommands
 from pilot.configs.config import Config
-from pilot.prompts.generator import PromptGenerator
+from pilot.prompts.generator import PluginPromptGenerator
 from pilot.speech import say_text
 
 
@@ -24,8 +23,8 @@ def _resolve_pathlike_command_args(command_args):
 
 
 def execute_ai_response_json(
-    prompt: PromptGenerator,
-    ai_response: str,
+    prompt: PluginPromptGenerator,
+    ai_response,
     user_input: str = None,
 ) -> str:
     """
@@ -39,11 +38,8 @@ def execute_ai_response_json(
     """
     cfg = Config()
-    try:
-        assistant_reply_json = fix_json_using_multiple_techniques(ai_response)
-    except (json.JSONDecodeError, ValueError, AttributeError) as e:
-        raise NotCommands("非可执行命令结构")
-    command_name, arguments = get_command(assistant_reply_json)
+    command_name, arguments = get_command(ai_response)
 
     if cfg.speak_mode:
         say_text(f"I want to execute {command_name}")
@@ -71,7 +67,7 @@ def execute_ai_response_json(
 def execute_command(
     command_name: str,
     arguments,
-    prompt: PromptGenerator,
+    prompt: PluginPromptGenerator,
 ):
     """Execute the command and return the result
 
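With the repair helper gone, `execute_ai_response_json` now hands the raw model reply straight to `get_command`. A hedged sketch of the resulting call (the generator construction and sample reply are illustrative, not taken from the repository):

    from pilot.prompts.generator import PluginPromptGenerator

    plugin_prompt = PluginPromptGenerator()  # real callers configure commands first
    ai_response = '{"command": {"name": "db_sql_executor", "args": {"sql": "SELECT 1"}}}'
    result = execute_ai_response_json(prompt=plugin_prompt, ai_response=ai_response)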
@@ -1,21 +1,21 @@
 import markdown2
 import pandas as pd
 
 
 def datas_to_table_html(data):
     df = pd.DataFrame(data[1:], columns=data[0])
     table_style = """<style>
     table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
     </style>"""
     html_table = df.to_html(index=False, escape=False)
 
     html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
 
     return html.replace("\n", " ")
 
 
 def generate_markdown_table(data):
-    """\n Generate a Markdown table\n data: a 2-D list containing the header row and the table body\n """
+    """\n Generate a Markdown table\n data: a 2-D list containing the header row and the table body\n"""
     # Get the number of table columns
     num_cols = len(data[0])
     # Generate the header row
@@ -41,6 +41,7 @@ def generate_markdown_table(data):
     return table
 
 
 def generate_htm_table(data):
     markdown_text = generate_markdown_table(data)
     html_table = markdown2.markdown(markdown_text, extras=["tables"])
@@ -53,4 +54,4 @@ if __name__ == "__main__":
     table_style = """<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid #ddd;\n padding: 8px;\n text-align: center;\n line-height: 150px; \n }\n th {\n background-color: #f2f2f2;\n color: #333;\n font-weight: bold;\n }\n tr:nth-child(even) {\n background-color: #f9f9f9;\n }\n tr:hover {\n background-color: #f2f2f2;\n }\n </style>"""
 
     print(table_style.replace("\n", " "))
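Both table helpers expect the same input shape: row 0 is the header, the rest are body rows. A quick usage sketch with invented data:

    data = [
        ["name", "age"],  # header row
        ["alice", 30],
        ["bob", 25],
    ]
    print(datas_to_table_html(data))  # styled <html> document, newlines collapsed
    print(generate_htm_table(data))   # Markdown table rendered to HTML via markdown2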
@@ -1,8 +1,9 @@
 from enum import auto, Enum
 from typing import List, Any
 
 
 class SeparatorStyle(Enum):
-    SINGLE ="###"
+    SINGLE = "###"
     TWO = "</s>"
     THREE = auto()
     FOUR = auto()
@@ -30,16 +30,16 @@ class Database:
     """SQLAlchemy wrapper around a database."""
 
     def __init__(
         self,
         engine,
         schema: Optional[str] = None,
         metadata: Optional[MetaData] = None,
         ignore_tables: Optional[List[str]] = None,
         include_tables: Optional[List[str]] = None,
         sample_rows_in_table_info: int = 3,
         indexes_in_table_info: bool = False,
         custom_table_info: Optional[dict] = None,
         view_support: bool = False,
     ):
         """Create engine from database URI."""
         self._engine = engine
@@ -119,7 +119,7 @@ class Database:
     @classmethod
     def from_uri(
         cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
     ) -> Database:
         """Construct a SQLAlchemy engine from URI."""
         _engine_args = engine_args or {}
@@ -148,7 +148,7 @@ class Database:
         self._metadata = MetaData()
         # sql = f"use {db_name}"
-        sql = text(f'use `{db_name}`')
+        sql = text(f"use `{db_name}`")
         session.execute(sql)
 
         # Process table metadata
@@ -159,13 +159,17 @@ class Database:
         # tables list if view_support is True
         self._all_tables = set(
             self._inspector.get_table_names(schema=db_name)
-            + (self._inspector.get_view_names(schema=db_name) if self.view_support else [])
+            + (
+                self._inspector.get_view_names(schema=db_name)
+                if self.view_support
+                else []
+            )
         )
 
         return session
 
     def get_current_db_name(self, session) -> str:
-        return session.execute(text('SELECT DATABASE()')).scalar()
+        return session.execute(text("SELECT DATABASE()")).scalar()
 
     def table_simple_info(self, session):
         _sql = f"""
@@ -201,7 +205,7 @@ class Database:
             tbl
             for tbl in self._metadata.sorted_tables
             if tbl.name in set(all_table_names)
             and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
         ]
 
         tables = []
@@ -214,7 +218,7 @@ class Database:
             create_table = str(CreateTable(table).compile(self._engine))
             table_info = f"{create_table.rstrip()}"
             has_extra_info = (
                 self._indexes_in_table_info or self._sample_rows_in_table_info
             )
             if has_extra_info:
                 table_info += "\n\n/*"
@@ -303,6 +307,10 @@ class Database:
 
     def get_database_list(self):
         session = self._db_sessions()
-        cursor = session.execute(text(' show databases;'))
+        cursor = session.execute(text(" show databases;"))
         results = cursor.fetchall()
-        return [d[0] for d in results if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]]
+        return [
+            d[0]
+            for d in results
+            if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
+        ]
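Putting the reformatted pieces together, usage of this wrapper looks roughly as follows (credentials are placeholders; see the `Config` change later in this commit for the real call site):

    db = Database.from_uri(
        "mysql+pymysql://user:password@127.0.0.1:3306",
        engine_args={"pool_size": 10, "pool_recycle": 3600},
    )
    session = db.get_session("mydb")  # runs use `mydb` and reflects its tables
    print(db.get_database_list())     # system schemas are filtered out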
@@ -1,167 +0,0 @@
-# sourcery skip: do-not-use-staticmethod
-"""
-A module that contains the AIConfig class object that contains the configuration
-"""
-from __future__ import annotations
-
-import os
-import platform
-from pathlib import Path
-from typing import Optional
-
-import distro
-import yaml
-
-from pilot.configs.config import Config
-from pilot.prompts.generator import PromptGenerator
-from pilot.prompts.prompt import build_default_prompt_generator
-
-# Soon this will go in a folder where it remembers more stuff about the run(s)
-SAVE_FILE = str(Path(os.getcwd()) / "ai_settings.yaml")
-
-
-class AIConfig:
-    """
-    A class object that contains the configuration information for the AI
-
-    Attributes:
-        ai_name (str): The name of the AI.
-        ai_role (str): The description of the AI's role.
-        ai_goals (list): The list of objectives the AI is supposed to complete.
-        api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
-    """
-
-    def __init__(
-        self,
-        ai_name: str = "",
-        ai_role: str = "",
-        ai_goals: list | None = None,
-        api_budget: float = 0.0,
-    ) -> None:
-        """
-        Initialize a class instance
-
-        Parameters:
-            ai_name (str): The name of the AI.
-            ai_role (str): The description of the AI's role.
-            ai_goals (list): The list of objectives the AI is supposed to complete.
-            api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
-        Returns:
-            None
-        """
-        if ai_goals is None:
-            ai_goals = []
-        self.ai_name = ai_name
-        self.ai_role = ai_role
-        self.ai_goals = ai_goals
-        self.api_budget = api_budget
-        self.prompt_generator = None
-        self.command_registry = None
-
-    @staticmethod
-    def load(config_file: str = SAVE_FILE) -> "AIConfig":
-        """
-        Returns class object with parameters (ai_name, ai_role, ai_goals, api_budget) loaded from
-        yaml file if yaml file exists,
-        else returns class with no parameters.
-
-        Parameters:
-            config_file (int): The path to the config yaml file.
-                DEFAULT: "../ai_settings.yaml"
-
-        Returns:
-            cls (object): An instance of given cls object
-        """
-
-        try:
-            with open(config_file, encoding="utf-8") as file:
-                config_params = yaml.load(file, Loader=yaml.FullLoader)
-        except FileNotFoundError:
-            config_params = {}
-
-        ai_name = config_params.get("ai_name", "")
-        ai_role = config_params.get("ai_role", "")
-        ai_goals = [
-            str(goal).strip("{}").replace("'", "").replace('"', "")
-            if isinstance(goal, dict)
-            else str(goal)
-            for goal in config_params.get("ai_goals", [])
-        ]
-        api_budget = config_params.get("api_budget", 0.0)
-        # type is Type[AIConfig]
-        return AIConfig(ai_name, ai_role, ai_goals, api_budget)
-
-    def save(self, config_file: str = SAVE_FILE) -> None:
-        """
-        Saves the class parameters to the specified file yaml file path as a yaml file.
-
-        Parameters:
-            config_file(str): The path to the config yaml file.
-                DEFAULT: "../ai_settings.yaml"
-
-        Returns:
-            None
-        """
-
-        config = {
-            "ai_name": self.ai_name,
-            "ai_role": self.ai_role,
-            "ai_goals": self.ai_goals,
-            "api_budget": self.api_budget,
-        }
-        with open(config_file, "w", encoding="utf-8") as file:
-            yaml.dump(config, file, allow_unicode=True)
-
-    def construct_full_prompt(
-        self, prompt_generator: Optional[PromptGenerator] = None
-    ) -> str:
-        """
-        Returns a prompt to the user with the class information in an organized fashion.
-
-        Parameters:
-            None
-
-        Returns:
-            full_prompt (str): A string containing the initial prompt for the user
-                including the ai_name, ai_role, ai_goals, and api_budget.
-        """
-
-        prompt_start = (
-            "Your decisions must always be made independently without"
-            " seeking user assistance. Play to your strengths as an LLM and pursue"
-            " simple strategies with no legal complications."
-            ""
-        )
-
-        cfg = Config()
-        if prompt_generator is None:
-            prompt_generator = build_default_prompt_generator()
-        prompt_generator.goals = self.ai_goals
-        prompt_generator.name = self.ai_name
-        prompt_generator.role = self.ai_role
-        prompt_generator.command_registry = self.command_registry
-        for plugin in cfg.plugins:
-            if not plugin.can_handle_post_prompt():
-                continue
-            prompt_generator = plugin.post_prompt(prompt_generator)
-
-        if cfg.execute_local_commands:
-            # add OS info to prompt
-            os_name = platform.system()
-            os_info = (
-                platform.platform(terse=True)
-                if os_name != "Linux"
-                else distro.name(pretty=True)
-            )
-
-            prompt_start += f"\nThe OS you are running on is: {os_info}"
-
-        # Construct full prompt
-        full_prompt = f"You are {prompt_generator.name}, {prompt_generator.role}\n{prompt_start}\n\nGOALS:\n\n"
-        for i, goal in enumerate(self.ai_goals):
-            full_prompt += f"{i+1}. {goal}\n"
-        if self.api_budget > 0.0:
-            full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
-        self.prompt_generator = prompt_generator
-        full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
-        return full_prompt
@@ -47,7 +47,6 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
         self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
 
-
         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
         self.image_provider = os.getenv("IMAGE_PROVIDER", True)
@@ -104,8 +103,17 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
 
         ### TODO Adapt to multiple types of libraries
-        self.local_db = Database.from_uri("mysql+pymysql://" + self.LOCAL_DB_USER +":"+ self.LOCAL_DB_PASSWORD +"@" +self.LOCAL_DB_HOST + ":" + str(self.LOCAL_DB_PORT) ,
-                                          engine_args ={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        self.local_db = Database.from_uri(
+            "mysql+pymysql://"
+            + self.LOCAL_DB_USER
+            + ":"
+            + self.LOCAL_DB_PASSWORD
+            + "@"
+            + self.LOCAL_DB_HOST
+            + ":"
+            + str(self.LOCAL_DB_PORT),
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )
 
         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
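The reformatted concatenation chain is equivalent to a single f-string; a sketch of the same statement inside `Config.__init__` (illustrative only, not what the commit uses):

    uri = (
        f"mysql+pymysql://{self.LOCAL_DB_USER}:{self.LOCAL_DB_PASSWORD}"
        f"@{self.LOCAL_DB_HOST}:{self.LOCAL_DB_PORT}"
    )
    self.local_db = Database.from_uri(
        uri, engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}
    )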
@@ -2,7 +2,34 @@
 # -*- coding:utf-8 -*-
 
 """We need to design a base class. That other connector can Write with this"""
+from abc import ABC, abstractmethod
+from pydantic import BaseModel, Extra, Field, root_validator
+from typing import Any, Iterable, List, Optional
 
 
-class BaseConnection:
-    pass
+class BaseConnect(BaseModel, ABC):
+    type
+    driver: str
+
+    def get_session(self, db_name: str):
+        pass
+
+    def get_table_names(self) -> Iterable[str]:
+        pass
+
+    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+        pass
+
+    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+        pass
+
+    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
+        pass
+
+    def get_database_list(self):
+        pass
+
+    def run(self, session, command: str, fetch: str = "all") -> List:
+        pass
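A concrete connector is expected to subclass this base and fill in the class-level fields, as `MySQLConnect` does below. A hypothetical example (not part of this commit):

    from pilot.connections.base import BaseConnect

    class SQLiteConnect(BaseConnect):
        # Hypothetical subclass for illustration only
        type: str = "SQLite"
        driver: str = "sqlite"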
@@ -1,64 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import pymysql
-
-
-class MySQLOperator:
-    """Connect MySQL Database fetch MetaData For LLM Prompt
-    Args:
-
-    Usage:
-    """
-
-    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
-
-    def __init__(self, user, password, host="localhost", port=3306) -> None:
-        self.conn = pymysql.connect(
-            host=host,
-            user=user,
-            port=port,
-            passwd=password,
-            charset="utf8mb4",
-            cursorclass=pymysql.cursors.DictCursor,
-        )
-
-    def get_schema(self, schema_name):
-        with self.conn.cursor() as cursor:
-            _sql = f"""
-                select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{schema_name}" group by TABLE_NAME;
-            """
-            cursor.execute(_sql)
-            results = cursor.fetchall()
-            return results
-
-    def run_sql(self, db_name:str, sql:str, fetch: str = "all"):
-        with self.conn.cursor() as cursor:
-            cursor.execute("USE " + db_name)
-            cursor.execute(sql)
-            if fetch == "all":
-                result = cursor.fetchall()
-            elif fetch == "one":
-                result = cursor.fetchone()[0]  # type: ignore
-            else:
-                raise ValueError("Fetch parameter must be either 'one' or 'all'")
-            return str(result)
-
-    def get_index(self, schema_name):
-        pass
-
-    def get_db_list(self):
-        with self.conn.cursor() as cursor:
-            _sql = """
-                show databases;
-            """
-            cursor.execute(_sql)
-            results = cursor.fetchall()
-
-            dbs = [
-                d["Database"] for d in results if d["Database"] not in self.default_db
-            ]
-            return dbs
-
-    def get_meta(self, schema_name):
-        pass
pilot/connections/rdbms/__init__.py (new file, 0 lines)

pilot/connections/rdbms/mysql.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import pymysql
+from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
+
+
+class MySQLConnect(RDBMSDatabase):
+    """Connect MySQL Database fetch MetaData For LLM Prompt
+    Args:
+    Usage:
+    """
+
+    type: str = "MySQL"
+    connect_url = "mysql+pymysql://"
+
+    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
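`MySQLConnect` adds only metadata; behavior comes from `RDBMSDatabase` in the next new file, so construction presumably goes through the inherited `from_uri` (connection details are placeholders):

    from pilot.connections.rdbms.mysql import MySQLConnect

    conn = MySQLConnect.from_uri("mysql+pymysql://user:password@127.0.0.1:3306")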
pilot/connections/rdbms/rdbms_connect.py (new file, 318 lines)
@@ -0,0 +1,318 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any, Iterable, List, Optional
+from pydantic import BaseModel, Field, root_validator, validator, Extra
+from abc import ABC, abstractmethod
+import sqlalchemy
+from sqlalchemy import (
+    MetaData,
+    Table,
+    create_engine,
+    inspect,
+    select,
+    text,
+)
+from sqlalchemy.engine import CursorResult, Engine
+from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
+from sqlalchemy.schema import CreateTable
+from sqlalchemy.orm import sessionmaker, scoped_session
+
+from pilot.connections.base import BaseConnect
+
+
+def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
+    return (
+        f'Name: {index["name"]}, Unique: {index["unique"]},'
+        f' Columns: {str(index["column_names"])}'
+    )
+
+
+class RDBMSDatabase(BaseConnect):
+    """SQLAlchemy wrapper around a database."""
+
+    def __init__(
+        self,
+        engine,
+        schema: Optional[str] = None,
+        metadata: Optional[MetaData] = None,
+        ignore_tables: Optional[List[str]] = None,
+        include_tables: Optional[List[str]] = None,
+        sample_rows_in_table_info: int = 3,
+        indexes_in_table_info: bool = False,
+        custom_table_info: Optional[dict] = None,
+        view_support: bool = False,
+    ):
+        """Create engine from database URI."""
+        self._engine = engine
+        self._schema = schema
+        if include_tables and ignore_tables:
+            raise ValueError("Cannot specify both include_tables and ignore_tables")
+
+        self._inspector = inspect(self._engine)
+        session_factory = sessionmaker(bind=engine)
+        Session = scoped_session(session_factory)
+
+        self._db_sessions = Session
+
+        self._all_tables = set()
+        self.view_support = False
+        self._usable_tables = set()
+        self._include_tables = set()
+        self._ignore_tables = set()
+        self._custom_table_info = set()
+        self._indexes_in_table_info = set()
+        self._usable_tables = set()
+        self._usable_tables = set()
+        self._sample_rows_in_table_info = set()
+        # including view support by adding the views as well as tables to the all
+        # tables list if view_support is True
+        # self._all_tables = set(
+        #     self._inspector.get_table_names(schema=schema)
+        #     + (self._inspector.get_view_names(schema=schema) if view_support else [])
+        # )
+
+        # self._include_tables = set(include_tables) if include_tables else set()
+        # if self._include_tables:
+        #     missing_tables = self._include_tables - self._all_tables
+        #     if missing_tables:
+        #         raise ValueError(
+        #             f"include_tables {missing_tables} not found in database"
+        #         )
+        # self._ignore_tables = set(ignore_tables) if ignore_tables else set()
+        # if self._ignore_tables:
+        #     missing_tables = self._ignore_tables - self._all_tables
+        #     if missing_tables:
+        #         raise ValueError(
+        #             f"ignore_tables {missing_tables} not found in database"
+        #         )
+        # usable_tables = self.get_usable_table_names()
+        # self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
+
+        # if not isinstance(sample_rows_in_table_info, int):
+        #     raise TypeError("sample_rows_in_table_info must be an integer")
+        #
+        # self._sample_rows_in_table_info = sample_rows_in_table_info
+        # self._indexes_in_table_info = indexes_in_table_info
+        #
+        # self._custom_table_info = custom_table_info
+        # if self._custom_table_info:
+        #     if not isinstance(self._custom_table_info, dict):
+        #         raise TypeError(
+        #             "table_info must be a dictionary with table names as keys and the "
+        #             "desired table info as values"
+        #         )
+        #     # only keep the tables that are also present in the database
+        #     intersection = set(self._custom_table_info).intersection(self._all_tables)
+        #     self._custom_table_info = dict(
+        #         (table, self._custom_table_info[table])
+        #         for table in self._custom_table_info
+        #         if table in intersection
+        #     )
+
+        # self._metadata = metadata or MetaData()
+        # # # including view support if view_support = true
+        # self._metadata.reflect(
+        #     views=view_support,
+        #     bind=self._engine,
+        #     only=list(self._usable_tables),
+        #     schema=self._schema,
+        # )
+
+    @classmethod
+    def from_uri(
+        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> RDBMSDatabase:
+        """Construct a SQLAlchemy engine from URI."""
+        _engine_args = engine_args or {}
+        return cls(create_engine(database_uri, **_engine_args), **kwargs)
+
+    @property
+    def dialect(self) -> str:
+        """Return string representation of dialect to use."""
+        return self._engine.dialect.name
+
+    def get_usable_table_names(self) -> Iterable[str]:
+        """Get names of tables available."""
+        if self._include_tables:
+            return self._include_tables
+        return self._all_tables - self._ignore_tables
+
+    def get_table_names(self) -> Iterable[str]:
+        """Get names of tables available."""
+        warnings.warn(
+            "This method is deprecated - please use `get_usable_table_names`."
+        )
+        return self.get_usable_table_names()
+
+    def get_session(self, db_name: str):
+        session = self._db_sessions()
+
+        self._metadata = MetaData()
+        # sql = f"use {db_name}"
+        sql = text(f"use `{db_name}`")
+        session.execute(sql)
+
+        # Process table metadata
+
+        self._metadata.reflect(bind=self._engine, schema=db_name)
+
+        # including view support by adding the views as well as tables to the all
+        # tables list if view_support is True
+        self._all_tables = set(
+            self._inspector.get_table_names(schema=db_name)
+            + (
+                self._inspector.get_view_names(schema=db_name)
+                if self.view_support
+                else []
+            )
+        )
+
+        return session
+
+    def get_current_db_name(self, session) -> str:
+        return session.execute(text("SELECT DATABASE()")).scalar()
+
+    def table_simple_info(self, session):
+        _sql = f"""
+            select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
+        """
+        cursor = session.execute(text(_sql))
+        results = cursor.fetchall()
+        return results
+
+    @property
+    def table_info(self) -> str:
+        """Information about all tables in the database."""
+        return self.get_table_info()
+
+    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+        """Get information about specified tables.
+
+        Follows best practices as specified in: Rajkumar et al, 2022
+        (https://arxiv.org/abs/2204.00498)
+
+        If `sample_rows_in_table_info`, the specified number of sample rows will be
+        appended to each table description. This can increase performance as
+        demonstrated in the paper.
+        """
+        all_table_names = self.get_usable_table_names()
+        if table_names is not None:
+            missing_tables = set(table_names).difference(all_table_names)
+            if missing_tables:
+                raise ValueError(f"table_names {missing_tables} not found in database")
+            all_table_names = table_names
+
+        meta_tables = [
+            tbl
+            for tbl in self._metadata.sorted_tables
+            if tbl.name in set(all_table_names)
+            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
+        ]
+
+        tables = []
+        for table in meta_tables:
+            if self._custom_table_info and table.name in self._custom_table_info:
+                tables.append(self._custom_table_info[table.name])
+                continue
+
+            # add create table command
+            create_table = str(CreateTable(table).compile(self._engine))
+            table_info = f"{create_table.rstrip()}"
+            has_extra_info = (
+                self._indexes_in_table_info or self._sample_rows_in_table_info
+            )
+            if has_extra_info:
+                table_info += "\n\n/*"
+            if self._indexes_in_table_info:
+                table_info += f"\n{self._get_table_indexes(table)}\n"
+            if self._sample_rows_in_table_info:
+                table_info += f"\n{self._get_sample_rows(table)}\n"
+            if has_extra_info:
+                table_info += "*/"
+            tables.append(table_info)
+        final_str = "\n\n".join(tables)
+        return final_str
+
+    def _get_sample_rows(self, table: Table) -> str:
+        # build the select command
+        command = select(table).limit(self._sample_rows_in_table_info)
+
+        # save the columns in string format
+        columns_str = "\t".join([col.name for col in table.columns])
+
+        try:
+            # get the sample rows
+            with self._engine.connect() as connection:
+                sample_rows_result: CursorResult = connection.execute(command)
+                # shorten values in the sample rows
+                sample_rows = list(
+                    map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
+                )
+
+            # save the sample rows in string format
+            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
+
+        # in some dialects when there are no rows in the table a
+        # 'ProgrammingError' is returned
+        except ProgrammingError:
+            sample_rows_str = ""
+
+        return (
+            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
+            f"{columns_str}\n"
+            f"{sample_rows_str}"
+        )
+
+    def _get_table_indexes(self, table: Table) -> str:
+        indexes = self._inspector.get_indexes(table.name)
+        indexes_formatted = "\n".join(map(_format_index, indexes))
+        return f"Table Indexes:\n{indexes_formatted}"
+
+    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
+        """Get information about specified tables."""
+        try:
+            return self.get_table_info(table_names)
+        except ValueError as e:
+            """Format the error message"""
+            return f"Error: {e}"
+
+    def run(self, session, command: str, fetch: str = "all") -> List:
+        """Execute a SQL command and return a string representing the results."""
+        cursor = session.execute(text(command))
+        if cursor.returns_rows:
+            if fetch == "all":
+                result = cursor.fetchall()
+            elif fetch == "one":
+                result = cursor.fetchone()[0]  # type: ignore
+            else:
+                raise ValueError("Fetch parameter must be either 'one' or 'all'")
+            field_names = tuple(i[0:] for i in cursor.keys())
+
+            result = list(result)
+            result.insert(0, field_names)
+            return result
+
+    def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
+        """Execute a SQL command and return a string representing the results.
+
+        If the statement returns rows, a string of the results is returned.
+        If the statement returns no rows, an empty string is returned.
+
+        If the statement throws an error, the error message is returned.
+        """
+        try:
+            return self.run(session, command, fetch)
+        except SQLAlchemyError as e:
+            """Format the error message"""
+            return f"Error: {e}"
+
+    def get_database_list(self):
+        session = self._db_sessions()
+        cursor = session.execute(text(" show databases;"))
+        results = cursor.fetchall()
+        return [
+            d[0]
+            for d in results
+            if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
+        ]
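An end-to-end sketch of the new wrapper, mirroring the old `Database` API (credentials invented):

    from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase

    db = RDBMSDatabase.from_uri(
        "mysql+pymysql://user:password@127.0.0.1:3306",
        engine_args={"pool_size": 10},
    )
    session = db.get_session("mydb")
    rows = db.run_no_throw(session, "SELECT 1")
    print(rows)  # [(column names), (1,)] on success, or an "Error: ..." string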
@@ -44,6 +44,7 @@ class Conversation:
     skip_next: bool = False
     conv_id: Any = None
     last_user_input: Any = None
 
     def get_prompt(self):
         if self.sep_style == SeparatorStyle.SINGLE:
             ret = self.system + self.sep
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from pydantic import BaseModel, Field, root_validator, validator,Extra
+from pydantic import BaseModel, Field, root_validator, validator, Extra
 from abc import ABC, abstractmethod
 from typing import (
     Any,
@@ -17,13 +17,9 @@ from typing import (
 from pilot.scene.message import OnceConversation
 
-
-
-
-
 class BaseChatHistoryMemory(ABC):
 
     def __init__(self):
-        self.conversations:List[OnceConversation] = []
+        self.conversations: List[OnceConversation] = []
 
     @abstractmethod
     def messages(self) -> List[OnceConversation]:  # type: ignore
@@ -33,8 +29,6 @@ class BaseChatHistoryMemory(ABC):
     def append(self, message: OnceConversation) -> None:
         """Append the message to the record in the local file"""
 
-
-
     @abstractmethod
     def clear(self) -> None:
         """Clear session memory from the local file"""
@@ -5,31 +5,33 @@ import datetime
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pathlib import Path
 
 from pilot.configs.config import Config
-from pilot.scene.message import OnceConversation, conversation_from_dict,conversations_to_dict
+from pilot.scene.message import (
+    OnceConversation,
+    conversation_from_dict,
+    conversations_to_dict,
+)
 
 CFG = Config()
 
 
 class FileHistoryMemory(BaseChatHistoryMemory):
-    def __init__(self, chat_session_id:str):
+    def __init__(self, chat_session_id: str):
         now = datetime.datetime.now()
         date_string = now.strftime("%Y%m%d")
         path: str = f"{CFG.message_dir}/{date_string}"
         os.makedirs(path, exist_ok=True)
 
         dir_path = Path(path)
         self.file_path = Path(dir_path / f"{chat_session_id}.json")
         if not self.file_path.exists():
             self.file_path.touch()
             self.file_path.write_text(json.dumps([]))
 
     def messages(self) -> List[OnceConversation]:
         items = json.loads(self.file_path.read_text())
-        history:List[OnceConversation] = []
+        history: List[OnceConversation] = []
         for onece in items:
             messages = conversation_from_dict(onece)
             history.append(messages)
@@ -38,8 +40,10 @@ class FileHistoryMemory(BaseChatHistoryMemory):
     def append(self, once_message: OnceConversation) -> None:
         historys = self.messages()
         historys.append(once_message)
-        self.file_path.write_text(json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4), encoding="UTF-8")
+        self.file_path.write_text(
+            json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4),
+            encoding="UTF-8",
+        )
 
     def clear(self) -> None:
         self.file_path.write_text(json.dumps([]))
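Usage sketch for the history store (the module path and the `OnceConversation` construction are assumptions for illustration):

    from pilot.memory.chat_history.file_history import FileHistoryMemory
    from pilot.scene.message import OnceConversation

    memory = FileHistoryMemory("session-42")  # file under {message_dir}/{YYYYMMDD}/
    memory.append(OnceConversation())         # assumed default-constructible
    print(len(memory.messages()))             # -> 1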
@@ -13,13 +13,15 @@ from typing import (
     TypeVar,
     Union,
 )
+from pilot.utils import build_logger
+import re
 
 from pydantic import BaseModel, Extra, Field, root_validator
+from pilot.configs.model_config import LOGDIR
 from pilot.prompts.base import PromptValue
 
 T = TypeVar("T")
+logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+
 
 class BaseOutputParser(ABC):
     """Class to parse the output of an LLM call.
@@ -41,16 +43,16 @@ class BaseOutputParser(ABC):
         text = text.lower()
         respObj = json.loads(text)
 
-        xx = respObj['response']
-        xx = xx.strip(b'\x00'.decode())
+        xx = respObj["response"]
+        xx = xx.strip(b"\x00".decode())
         respObj_ex = json.loads(xx)
-        if respObj_ex['error_code'] == 0:
-            all_text = respObj_ex['text']
+        if respObj_ex["error_code"] == 0:
+            all_text = respObj_ex["text"]
             ### Parse the returned text and extract the AI reply part
             tmpResp = all_text.split(sep)
             last_index = -1
             for i in range(len(tmpResp)):
-                if tmpResp[i].find('assistant:') != -1:
+                if tmpResp[i].find("assistant:") != -1:
                     last_index = i
             ai_response = tmpResp[last_index]
             ai_response = ai_response.replace("assistant:", "")
@@ -60,9 +62,7 @@ class BaseOutputParser(ABC):
             print("un_stream clear response:{}", ai_response)
             return ai_response
         else:
-            raise ValueError("Model server error!code=" + respObj_ex['error_code']);
+            raise ValueError("Model server error!code=" + respObj_ex["error_code"])
-
-
 
     def parse_model_server_out(self, response) -> str:
         """
@@ -87,7 +87,27 @@ class BaseOutputParser(ABC):
         Returns:
 
         """
-        pass
+        cleaned_output = model_out_text.rstrip()
+        if "```json" in cleaned_output:
+            _, cleaned_output = cleaned_output.split("```json")
+        if "```" in cleaned_output:
+            cleaned_output, _ = cleaned_output.split("```")
+        if cleaned_output.startswith("```json"):
+            cleaned_output = cleaned_output[len("```json"):]
+        if cleaned_output.startswith("```"):
+            cleaned_output = cleaned_output[len("```"):]
+        if cleaned_output.endswith("```"):
+            cleaned_output = cleaned_output[: -len("```")]
+        cleaned_output = cleaned_output.strip()
+        if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
+            logger.info("illegal json processing")
+            json_pattern = r"{(.+?)}"
+            m = re.search(json_pattern, cleaned_output)
+            if m:
+                cleaned_output = m.group(0)
+            else:
+                raise ValueError("model server out not fllow the prompt!")
+        return cleaned_output
 
     def parse_view_response(self, ai_text) -> str:
         """
@@ -98,7 +118,7 @@ class BaseOutputParser(ABC):
         Returns:
 
         """
-        pass
+        return ai_text
 
     def get_format_instructions(self) -> str:
         """Instructions on how the LLM output should be formatted."""
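Tracing the new fence-stripping logic on a sample fenced reply (self-contained; mirrors the added lines above, including their reliance on at most one fenced block):

    fenced = '```json\n{"thoughts": "hi"}\n```'
    cleaned = fenced.rstrip()
    if "```json" in cleaned:
        _, cleaned = cleaned.split("```json")
    if "```" in cleaned:
        cleaned, _ = cleaned.split("```")
    print(cleaned.strip())  # -> {"thoughts": "hi"}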
@@ -1,143 +0,0 @@
-import platform
-from typing import Optional
-
-import distro
-import yaml
-
-from pilot.configs.config import Config
-from pilot.prompts.generator import PromptGenerator
-from pilot.prompts.prompt import (
-    DEFAULT_PROMPT_OHTER,
-    DEFAULT_TRIGGERING_PROMPT,
-    build_default_prompt_generator,
-)
-
-
-class AutoModePrompt:
-    """ """
-
-    def __init__(
-        self,
-        ai_goals: list | None = None,
-    ) -> None:
-        """
-        Initialize a class instance
-
-        Parameters:
-            ai_name (str): The name of the AI.
-            ai_role (str): The description of the AI's role.
-            ai_goals (list): The list of objectives the AI is supposed to complete.
-            api_budget (float): The maximum dollar value for API calls (0.0 means infinite)
-        Returns:
-            None
-        """
-        if ai_goals is None:
-            ai_goals = []
-        self.ai_goals = ai_goals
-        self.prompt_generator = None
-        self.command_registry = None
-
-    def construct_follow_up_prompt(
-        self,
-        user_input: [str],
-        last_auto_return: str = None,
-        prompt_generator: Optional[PromptGenerator] = None,
-    ) -> str:
-        """
-        Build complete prompt information based on subsequent dialogue information entered by the user
-        Args:
-            self:
-            prompt_generator:
-
-        Returns:
-
-        """
-        prompt_start = DEFAULT_PROMPT_OHTER
-        if prompt_generator is None:
-            prompt_generator = build_default_prompt_generator()
-        prompt_generator.goals = user_input
-        prompt_generator.command_registry = self.command_registry
-        # Load the commands available from plugins
-        cfg = Config()
-        for plugin in cfg.plugins:
-            if not plugin.can_handle_post_prompt():
-                continue
-            prompt_generator = plugin.post_prompt(prompt_generator)
-
-        full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
-        if not self.ai_goals:
-            self.ai_goals = user_input
-        for i, goal in enumerate(self.ai_goals):
-            full_prompt += (
-                f"{i+1}.According to the provided Schema information, {goal}\n"
-            )
-        # if last_auto_return == None:
-        #     full_prompt += f"{cfg.last_plugin_return}\n\n"
-        # else:
-        #     full_prompt += f"{last_auto_return}\n\n"
-
-        full_prompt += f"Constraints:\n\n{DEFAULT_TRIGGERING_PROMPT}\n"
-
-        full_prompt += """Based on the above definition, answer the current goal and ensure that the response meets both the current constraints and the above definition and constraints"""
-        self.prompt_generator = prompt_generator
-        return full_prompt
-
-    def construct_first_prompt(
-        self,
-        fisrt_message: [str] = [],
-        db_schemes: str = None,
-        prompt_generator: Optional[PromptGenerator] = None,
-    ) -> str:
-        """
-        Build complete prompt information based on the initial dialogue information entered by the user
-        Args:
-            self:
-            prompt_generator:
-
-        Returns:
-
-        """
-        prompt_start = (
-            "Your decisions must always be made independently without"
-            " seeking user assistance. Play to your strengths as an LLM and pursue"
-            " simple strategies with no legal complications."
-            ""
-        )
-        if prompt_generator is None:
-            prompt_generator = build_default_prompt_generator()
-        prompt_generator.goals = fisrt_message
-        prompt_generator.command_registry = self.command_registry
-        # Load the commands available from plugins
-        cfg = Config()
-        for plugin in cfg.plugins:
-            if not plugin.can_handle_post_prompt():
-                continue
-            prompt_generator = plugin.post_prompt(prompt_generator)
-        if cfg.execute_local_commands:
-            # add OS info to prompt
-            os_name = platform.system()
-            os_info = (
-                platform.platform(terse=True)
-                if os_name != "Linux"
-                else distro.name(pretty=True)
-            )
-
-            prompt_start += f"\nThe OS you are running on is: {os_info}"
-
-        # Construct full prompt
-        full_prompt = f"{prompt_start}\n\nGOALS:\n\n"
-        if not self.ai_goals:
-            self.ai_goals = fisrt_message
-        for i, goal in enumerate(self.ai_goals):
-            full_prompt += (
-                f"{i+1}.According to the provided Schema information,{goal}\n"
-            )
-        if db_schemes:
-            full_prompt += f"\nSchema:\n\n"
-            full_prompt += f"{db_schemes}"
-
-        # if self.api_budget > 0.0:
-        #     full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}"
-        self.prompt_generator = prompt_generator
-        full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
-        return full_prompt
@@ -1,5 +1,3 @@
-
-
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -8,7 +6,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
 import yaml
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from pilot.scene.base_message import BaseMessage,HumanMessage,AIMessage, SystemMessage
+from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
 
 
 def get_buffer_string(
@@ -29,7 +27,6 @@ def get_buffer_string(
     return "\n".join(string_messages)
 
 
-
 class PromptValue(BaseModel, ABC):
     @abstractmethod
     def to_string(self) -> str:
@@ -39,6 +36,7 @@ class PromptValue(BaseModel, ABC):
     def to_messages(self) -> List[BaseMessage]:
         """Return prompt as messages."""
 
+
 class ChatPromptValue(PromptValue):
     messages: List[BaseMessage]
 
@@ -48,4 +46,4 @@ class ChatPromptValue(PromptValue):
 
     def to_messages(self) -> List[BaseMessage]:
         """Return prompt as messages."""
         return self.messages
@@ -3,7 +3,7 @@ import json
 from typing import Any, Callable, Dict, List, Optional
 
 
-class PromptGenerator:
+class PluginPromptGenerator:
     """
     A class for generating custom prompt strings based on constraints, commands,
     resources, and performance evaluations.
@@ -133,6 +133,11 @@ class PluginPromptGenerator:
         else:
             return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
 
+    def generate_commands_string(self)->str:
+        return f"{self._generate_numbered_list(self.commands, item_type='command')}"
+
 
     def generate_prompt_string(self) -> str:
         """
         Generate a prompt string based on the constraints, commands, resources,
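The new `generate_commands_string` renders the registered commands as a numbered list; a usage sketch (argument values are invented, and the argument order follows the `add_command` signature shown elsewhere in this diff):

    from pilot.prompts.generator import PluginPromptGenerator

    generator = PluginPromptGenerator()
    generator.add_command("Execute SQL", "db_sql_executor", {"sql": "<sql>"})
    print(generator.generate_commands_string())  # numbered command list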
@@ -1,73 +0,0 @@
-from pilot.configs.config import Config
-from pilot.prompts.generator import PromptGenerator
-
-CFG = Config()
-
-DEFAULT_TRIGGERING_PROMPT = (
-    "Determine which next command to use, and respond using the format specified above"
-)
-
-DEFAULT_PROMPT_OHTER = "Previous response was excellent. Please response according to the requirements based on the new goal"
-
-
-def build_default_prompt_generator() -> PromptGenerator:
-    """
-    This function generates a prompt string that includes various constraints,
-    commands, resources, and performance evaluations.
-
-    Returns:
-        str: The generated prompt string.
-    """
-
-    # Initialize the PromptGenerator object
-    prompt_generator = PromptGenerator()
-
-    # Add constraints to the PromptGenerator object
-    # prompt_generator.add_constraint(
-    #     "~4000 word limit for short term memory. Your short term memory is short, so"
-    #     " immediately save important information to files."
-    # )
-    prompt_generator.add_constraint(
-        "If you are unsure how you previously did something or want to recall past"
-        " events, thinking about similar events will help you remember."
-    )
-    # prompt_generator.add_constraint("No user assistance")
-
-    prompt_generator.add_constraint("Only output one correct JSON response at a time")
-    prompt_generator.add_constraint(
-        'Exclusively use the commands listed in double quotes e.g. "command name"'
-    )
-    prompt_generator.add_constraint(
-        "If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition"
-    )
-    prompt_generator.add_constraint(
-        "The generated command args need to comply with the definition of the command"
-    )
-
-    # Add resources to the PromptGenerator object
-    # prompt_generator.add_resource(
-    #     "Internet access for searches and information gathering."
-    # )
-    # prompt_generator.add_resource("Long Term memory management.")
-    # prompt_generator.add_resource(
-    #     "DB-GPT powered Agents for delegation of simple tasks."
-    # )
-    # prompt_generator.add_resource("File output.")
-
-    # Add performance evaluations to the PromptGenerator object
-    prompt_generator.add_performance_evaluation(
-        "Continuously review and analyze your actions to ensure you are performing to"
-        " the best of your abilities."
-    )
-    prompt_generator.add_performance_evaluation(
-        "Constructively self-criticize your big-picture behavior constantly."
-    )
-    prompt_generator.add_performance_evaluation(
-        "Reflect on past decisions and strategies to refine your approach."
-    )
-    # prompt_generator.add_performance_evaluation(
-    #     "Every command has a cost, so be smart and efficient. Aim to complete tasks in"
-    #     " the least number of steps."
-    # )
-    # prompt_generator.add_performance_evaluation("Write all code to a file.")
-    return prompt_generator
@@ -3,8 +3,8 @@ from typing import Any, Callable, Dict, List, Optional

 class PromptGenerator:
     """
     generating custom prompt strings based on constraints;
     Compatible with AutoGpt Plugin;
     """

     def __init__(self) -> None:
@@ -22,8 +22,6 @@ class PromptGenerator:
         self.role = "AI"
         self.response_format = None

     def add_command(
         self,
         command_label: str,
@@ -51,4 +49,4 @@ class PromptGenerator:
             "args": command_args,
             "function": function,
         }
         self.commands.append(command)
@@ -8,6 +8,7 @@ from pilot.common.formatting import formatter
 from pilot.out_parser.base import BaseOutputParser
 from pilot.common.schema import SeparatorStyle

 def jinja2_formatter(template: str, **kwargs: Any) -> str:
     """Format a template using jinja2."""
     try:
@@ -32,22 +33,23 @@ class PromptTemplate(BaseModel, ABC):
     """A list of the names of the variables the prompt template expects."""
     template_scene: str

-    template_define:str
+    template_define: str
     """this template define"""
     template: str
     """The prompt template."""
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
-    response_format:str
+    response_format: str
     """default use stream out"""
     stream_out: bool = True
     """"""
     output_parser: BaseOutputParser = None
     """"""
-    sep:str = SeparatorStyle.SINGLE.value
+    sep: str = SeparatorStyle.SINGLE.value

     class Config:
         """Configuration for this pydantic object."""

         arbitrary_types_allowed = True

     @property
@@ -96,10 +98,8 @@ class PromptTemplate(BaseModel, ABC):
         else:
             return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))

-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt with the inputs.
-        """
+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs."""
         kwargs["response"] = json.dumps(self.response_format, indent=4)
         return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@@ -207,6 +207,7 @@ class BasePromptTemplate(BaseModel, ABC):
         else:
             raise ValueError(f"{save_path} must be json or yaml")

 class StringPromptValue(PromptValue):
     text: str
@@ -219,7 +220,6 @@ class StringPromptValue(PromptValue):
     return [HumanMessage(content=self.text)]

 class StringPromptTemplate(BasePromptTemplate, ABC):
     """String prompt should expose the format method, returning a prompt."""
@@ -360,4 +360,4 @@ class PromptTemplate(StringPromptTemplate):

 # For backwards compatibility.
 Prompt = PromptTemplate
@@ -1,8 +1,9 @@
 from enum import Enum

 class ChatScene(Enum):
     ChatWithDb = "chat_with_db"
+    ChatExecution = "chat_execution"
     ChatKnowledge = "chat_default_knowledge"
     ChatNewKnowledge = "chat_new_knowledge"
     ChatNormal = "chat_normal"
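For orientation, a small sketch (not part of the diff) of how these scene values are used; each value is a plain string that is matched against a chat implementation's `chat_scene` attribute (see ChatFactory further below):

    from pilot.scene.base import ChatScene

    mode = ChatScene.ChatExecution
    print(mode.value)  # "chat_execution"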
@@ -1,4 +1,6 @@
 from abc import ABC, abstractmethod
+import datetime
+import traceback
 from pydantic import BaseModel, Field, root_validator, validator, Extra
 from typing import (
     Any,
@@ -20,20 +22,27 @@ from pilot.prompts.prompt_new import PromptTemplate
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pilot.memory.chat_history.file_history import FileHistoryMemory

 from pilot.configs.model_config import LOGDIR, DATASETS_DIR
 from pilot.utils import (
     build_logger,
     server_error_msg,
 )
-from pilot.common.schema import SeparatorStyle
-from pilot.scene.base import ChatScene
+from pilot.scene.base_message import (
+    BaseMessage,
+    SystemMessage,
+    HumanMessage,
+    AIMessage,
+    ViewMessage,
+)
 from pilot.configs.config import Config

 logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
 headers = {"User-Agent": "dbgpt Client"}
 CFG = Config()

-class BaseChat( ABC):
-    chat_scene:str = None
+class BaseChat(ABC):
+    chat_scene: str = None
     llm_model: Any = None
     temperature: float = 0.6
     max_new_tokens: int = 1024
@@ -42,17 +51,20 @@ class BaseChat( ABC):

     class Config:
         """Configuration for this pydantic object."""

         arbitrary_types_allowed = True

     def __init__(self, chat_mode, chat_session_id, current_user_input):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
-        self.current_user_input:str = current_user_input
+        self.current_user_input: str = current_user_input
         self.llm_model = CFG.LLM_MODEL
         ### TODO
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
-        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[
+            self.chat_mode.value
+        ]
         self.history_message: List[OnceConversation] = []
         self.current_message: OnceConversation = OnceConversation()
         self.current_tokens_used: int = 0
@@ -69,15 +81,163 @@ class BaseChat( ABC):
     def chat_type(self) -> str:
         raise NotImplementedError("Not supported for this chat type.")

+    @abstractmethod
+    def generate_input_values(self):
+        pass
+
+    @abstractmethod
+    def do_with_prompt_response(self, prompt_response):
+        pass
+
     def call(self):
-        pass
+        input_values = self.generate_input_values()
+
+        ### Chat sequence advance
+        self.current_message.chat_order = len(self.history_message) + 1
+        self.current_message.add_user_message(self.current_user_input)
+        self.current_message.start_date = datetime.datetime.now()
+        # TODO
+        self.current_message.tokens = 0
+
+        current_prompt = self.prompt_template.format(**input_values)
+
+        ### Build the current conversation. Should the prompt be constructed as on the first turn? Should switching databases be considered?
+        if self.history_message:
+            ## TODO: scenes that carry conversation history need to define how a database switch is handled
+            logger.info(
+                f"There are already {len(self.history_message)} rounds of conversations!"
+            )
+
+        self.current_message.add_system_message(current_prompt)
+
+        payload = {
+            "model": self.llm_model,
+            "prompt": self.generate_llm_text(),
+            "temperature": float(self.temperature),
+            "max_new_tokens": int(self.max_new_tokens),
+            "stop": self.prompt_template.sep,
+        }
+        logger.info(f"Request: \n{payload}")
+        ai_response_text = ""
+        try:
+            if not self.prompt_template.stream_out:
+                ### call the non-streaming model service endpoint
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+
+                ### output parse
+                ai_response_text = (
+                    self.prompt_template.output_parser.parse_model_server_out(response)
+                )
+                self.current_message.add_ai_message(ai_response_text)
+                prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+
+                result = self.do_with_prompt_response(prompt_define_response)
+
+                if hasattr(prompt_define_response, "thoughts"):
+                    if prompt_define_response.thoughts.get("speak"):
+                        self.current_message.add_view_message(
+                            self.prompt_template.output_parser.parse_view_response(
+                                prompt_define_response.thoughts.get("speak"), result
+                            )
+                        )
+                    elif prompt_define_response.thoughts.get("reasoning"):
+                        self.current_message.add_view_message(
+                            self.prompt_template.output_parser.parse_view_response(
+                                prompt_define_response.thoughts.get("reasoning"), result
+                            )
+                        )
+                    else:
+                        self.current_message.add_view_message(
+                            self.prompt_template.output_parser.parse_view_response(
+                                prompt_define_response.thoughts, result
+                            )
+                        )
+                else:
+                    self.current_message.add_view_message(
+                        self.prompt_template.output_parser.parse_view_response(
+                            prompt_define_response, result
+                        )
+                    )
+            else:
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+                # TODO
+
+        except Exception as e:
+            print(traceback.format_exc())
+            logger.error("model response parse failed!" + str(e))
+            self.current_message.add_view_message(
+                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
+            )
+        ### persist the conversation record
+        self.memory.append(self.current_message)
+
+    def generate_llm_text(self) -> str:
+        text = self.prompt_template.template_define + self.prompt_template.sep
+        ### process the history messages first
+        if len(self.history_message) > self.chat_retention_rounds:
+            ### merge the first round and the last rounds of history into the context; view messages shown to the user are filtered out
+            for first_message in self.history_message[0].messages:
+                if not isinstance(first_message, ViewMessage):
+                    text += (
+                        first_message.type
+                        + ":"
+                        + first_message.content
+                        + self.prompt_template.sep
+                    )
+
+            index = self.chat_retention_rounds - 1
+            # iterate over each retained conversation's messages
+            for conversation in self.history_message[-index:]:
+                for last_message in conversation.messages:
+                    if not isinstance(last_message, ViewMessage):
+                        text += (
+                            last_message.type
+                            + ":"
+                            + last_message.content
+                            + self.prompt_template.sep
+                        )
+        else:
+            ### concatenate the history directly
+            for conversation in self.history_message:
+                for message in conversation.messages:
+                    if not isinstance(message, ViewMessage):
+                        text += (
+                            message.type
+                            + ":"
+                            + message.content
+                            + self.prompt_template.sep
+                        )
+
+        ### current conversation
+        for now_message in self.current_message.messages:
+            text += (
+                now_message.type + ":" + now_message.content + self.prompt_template.sep
+            )
+
+        return text

     def chat_show(self):
         pass

+    # kept for the time being for front-end compatibility
     def current_ai_response(self) -> str:
-        pass
+        for message in self.current_message.messages:
+            if message.type == "view":
+                return message.content
+        return None

     def _load_history(self, session_id: str) -> List[OnceConversation]:
         """
@@ -88,7 +248,7 @@ class BaseChat( ABC):
         """
         return self.memory.messages()

-    def generate(self, p)->str:
+    def generate(self, p) -> str:
         """
         generate context for LLM input
         Args:
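The refactor turns call() into a template method: each scene now only supplies generate_input_values() and do_with_prompt_response(), and BaseChat handles prompting, the model-server round trip, parsing, and persistence. A minimal sketch of a hypothetical subclass (the scene and return values are invented for illustration):

    from pilot.scene.base_chat import BaseChat
    from pilot.scene.base import ChatScene

    class ChatEcho(BaseChat):  # hypothetical scene, not part of the diff
        chat_scene: str = ChatScene.ChatNormal.value

        def generate_input_values(self):
            # Keys must match the registered prompt template's input_variables.
            return {"input": self.current_user_input}

        def do_with_prompt_response(self, prompt_response):
            # Receives the parsed model output; the return value is what the
            # output parser's parse_view_response renders for the user.
            return str(prompt_response)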
@@ -15,6 +15,7 @@ from typing import (
 from pydantic import BaseModel, Extra, Field, root_validator

 class PromptValue(BaseModel, ABC):
     @abstractmethod
     def to_string(self) -> str:
@@ -37,7 +38,6 @@ class BaseMessage(BaseModel):
     """Type of the message, used for serialization."""

 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
@@ -49,7 +49,6 @@ class HumanMessage(BaseMessage):
     return "human"

 class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
@@ -81,8 +80,6 @@ class SystemMessage(BaseMessage):
     return "system"

 class Generation(BaseModel):
     """Output of a single generation."""
@@ -94,7 +91,6 @@ class Generation(BaseModel):
     """May include things like reason for finishing (e.g. in OpenAI)"""

 class ChatGeneration(Generation):
     """Output of a single generation."""
@@ -126,7 +122,6 @@ class LLMResult(BaseModel):
     """For arbitrary LLM provider specific output."""

 def _message_to_dict(message: BaseMessage) -> dict:
     return {"type": message.type, "data": message.dict()}
@@ -149,6 +144,5 @@ def _message_from_dict(message: dict) -> BaseMessage:
     raise ValueError(f"Got unexpected type: {_type}")

 def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]
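A quick sketch of the round trip these helpers enable (assuming messages_to_dict mirrors messages_from_dict, as the message-store code further below suggests):

    from pilot.scene.base_message import (
        HumanMessage,
        AIMessage,
        messages_to_dict,
        messages_from_dict,
    )

    msgs = [HumanMessage(content="hi"), AIMessage(content="hello")]
    as_dicts = messages_to_dict(msgs)   # [{"type": "human", "data": {...}}, ...]
    restored = messages_from_dict(as_dicts)
    assert [m.content for m in restored] == ["hi", "hello"]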
@@ -14,7 +14,13 @@ from sqlalchemy import (
 )
 from typing import Any, Iterable, List, Optional

-from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage, ViewMessage
+from pilot.scene.base_message import (
+    BaseMessage,
+    SystemMessage,
+    HumanMessage,
+    AIMessage,
+    ViewMessage,
+)
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -25,22 +31,24 @@ from pilot.utils import (
     build_logger,
     server_error_msg,
 )
-from pilot.common.markdown_text import generate_markdown_table, generate_htm_table, datas_to_table_html
+from pilot.common.markdown_text import (
+    generate_markdown_table,
+    generate_htm_table,
+    datas_to_table_html,
+)
 from pilot.scene.chat_db.prompt import chat_db_prompt
 from pilot.out_parser.base import BaseOutputParser
 from pilot.scene.chat_db.out_parser import DbChatOutputParser

 CFG = Config()

 class ChatWithDb(BaseChat):
     chat_scene: str = ChatScene.ChatWithDb.value

     """Number of results to return from the query"""

     def __init__(self, chat_session_id, db_name, user_input):
-        """
-        """
+        """ """
         super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
         if not db_name:
             raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
@@ -50,118 +58,126 @@ class ChatWithDb(BaseChat):
         self.db_connect = self.database.get_session(self.db_name)
         self.top_k: int = 5

-    def call(self) -> str:
+    def generate_input_values(self):
         input_values = {
             "input": self.current_user_input,
             "top_k": str(self.top_k),
             "dialect": self.database.dialect,
-            "table_info": self.database.table_simple_info(self.db_connect),
-            # "stop": self.sep_style,
+            "table_info": self.database.table_simple_info(self.db_connect)
         }
+        return input_values
+
+    def do_with_prompt_response(self, prompt_response):
+        return self.database.run(self.db_connect, prompt_response.sql)
+
-        ### Chat sequence advance
-        self.current_message.chat_order = len(self.history_message) + 1
-        self.current_message.add_user_message(self.current_user_input)
-        self.current_message.start_date = datetime.datetime.now()
-        # TODO
-        self.current_message.tokens = 0
-
-        current_prompt = self.prompt_template.format(**input_values)
-
-        ### Build the current conversation. Should the prompt be constructed as on the first turn? Should switching databases be considered?
-        if self.history_message:
-            ## TODO: scenes that carry conversation history need to define how a database switch is handled
-            logger.info(f"There are already {len(self.history_message)} rounds of conversations!")
-
-        self.current_message.add_system_message(current_prompt)
-
-        payload = {
-            "model": self.llm_model,
-            "prompt": self.generate_llm_text(),
-            "temperature": float(self.temperature),
-            "max_new_tokens": int(self.max_new_tokens),
-            "stop": self.prompt_template.sep,
-        }
-        logger.info(f"Request: \n{payload}")
-        ai_response_text = ""
-        try:
-            ### call the non-streaming model service endpoint
-            response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120)
-            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
-            self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
-
-            result = self.database.run(self.db_connect, prompt_define_response.sql)
-
-            if hasattr(prompt_define_response, 'thoughts'):
-                if prompt_define_response.thoughts.get("speak"):
-                    self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("speak"), result))
-                elif prompt_define_response.thoughts.get("reasoning"):
-                    self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("reasoning"), result))
-                else:
-                    self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts, result))
-            else:
-                self.current_message.add_view_message(
-                    self.prompt_template.output_parser.parse_view_response(prompt_define_response, result))
-
-        except Exception as e:
-            print(traceback.format_exc())
-            logger.error("model response parse failed!" + str(e))
-            self.current_message.add_view_message(f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """)
-        ### persist the conversation record
-        self.memory.append(self.current_message)
+    # def call(self) -> str:
+    #     ...  (the removed call() body above is retained verbatim on the new
+    #     side of the diff as a commented-out block; it is superseded by the
+    #     shared BaseChat.call)

     def chat_show(self):
         ret = []
         # a single-turn conversation can have only one User record and one AI record
         # TODO: front-end display of the reasoning process...
         for message in self.current_message.messages:
-            if (isinstance(message, HumanMessage)):
+            if isinstance(message, HumanMessage):
                 ret[-1][-2] = message.content
             # whether to show the reasoning process
-            if (isinstance(message, ViewMessage)):
+            if isinstance(message, ViewMessage):
                 ret[-1][-1] = message.content

         return ret

-    # kept for the time being for front-end compatibility
-    def current_ai_response(self) -> str:
-        for message in self.current_message.messages:
-            if message.type == 'view':
-                return message.content
-        return None
-
-    def generate_llm_text(self) -> str:
-        text = self.prompt_template.template_define + self.prompt_template.sep
-        ### process the history messages first
-        if (len(self.history_message) > self.chat_retention_rounds):
-            ### merge the first round and the last rounds of history into the context; view messages shown to the user are filtered out
-            for first_message in self.history_message[0].messages:
-                if not isinstance(first_message, ViewMessage):
-                    text += first_message.type + ":" + first_message.content + self.prompt_template.sep
-
-            index = self.chat_retention_rounds - 1
-            for last_message in self.history_message[-index:].messages:
-                if not isinstance(last_message, ViewMessage):
-                    text += last_message.type + ":" + last_message.content + self.prompt_template.sep
-
-        else:
-            ### concatenate the history directly
-            for conversation in self.history_message:
-                for message in conversation.messages:
-                    if not isinstance(message, ViewMessage):
-                        text += message.type + ":" + message.content + self.prompt_template.sep
-
-        ### current conversation
-        for now_message in self.current_message.messages:
-            text += now_message.type + ":" + now_message.content + self.prompt_template.sep
-
-        return text

     @property
     def chat_type(self) -> str:
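Taken together, ChatWithDb now only supplies prompt inputs and executes the parsed SQL; the shared call() in BaseChat does the rest. A rough usage sketch (session id and database name invented):

    from pilot.scene.chat_db.chat import ChatWithDb

    chat = ChatWithDb(
        chat_session_id="demo-session",
        db_name="mydb",
        user_input="how many orders were placed today?",
    )
    chat.call()                        # format prompt -> model server -> run SQL
    print(chat.current_ai_response())  # rendered view message, e.g. a result table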
@@ -1,52 +1,28 @@
 import json
 import re
 from abc import ABC, abstractmethod
-from typing import (
-    Dict,
-    NamedTuple
-)
+from typing import Dict, NamedTuple
 import pandas as pd
 from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
 from pilot.configs.model_config import LOGDIR

 class SqlAction(NamedTuple):
     sql: str
     thoughts: Dict

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

 class DbChatOutputParser(BaseOutputParser):
-    def __init__(self, sep:str, is_stream_out: bool):
-        super().__init__(sep=sep, is_stream_out=is_stream_out )
-
-    def parse_model_server_out(self, response) -> str:
-        return super().parse_model_server_out(response)
+    def __init__(self, sep: str, is_stream_out: bool):
+        super().__init__(sep=sep, is_stream_out=is_stream_out)

     def parse_prompt_response(self, model_out_text):
-        cleaned_output = model_out_text.rstrip()
-        if "```json" in cleaned_output:
-            _, cleaned_output = cleaned_output.split("```json")
-        if "```" in cleaned_output:
-            cleaned_output, _ = cleaned_output.split("```")
-        if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
-        if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
-        if cleaned_output.endswith("```"):
-            cleaned_output = cleaned_output[: -len("```")]
-        cleaned_output = cleaned_output.strip()
-        if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
-            logger.info("illegal json processing")
-            json_pattern = r"{(.+?)}"
-            m = re.search(json_pattern, cleaned_output)
-            if m:
-                cleaned_output = m.group(0)
-            else:
-                raise ValueError("model server out does not follow the prompt!")
-
-        response = json.loads(cleaned_output)
+        response = json.loads(super().parse_prompt_response(model_out_text))
         sql, thoughts = response["sql"], response["thoughts"]
         return SqlAction(sql, thoughts)
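For illustration, a rough sense of what parse_prompt_response produces (assuming the base parser's parse_prompt_response strips code fences and returns the raw JSON text; the SQL and thoughts are invented):

    parser = DbChatOutputParser(sep="\n", is_stream_out=False)
    action = parser.parse_prompt_response(
        '{"thoughts": "query the five newest users", "sql": "SELECT * FROM users LIMIT 5"}'
    )
    print(action.sql)       # SELECT * FROM users LIMIT 5
    print(action.thoughts)  # query the five newest users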
@@ -45,9 +45,15 @@ RESPONSE_FORMAT = {
         "reasoning": "reasoning",
         "speak": "thoughts summary to say to user",
     },
-    "sql": "SQL Query to run"
+    "sql": "SQL Query to run",
 }

+RESPONSE_FORMAT_SIMPLE = {
+    "thoughts": "thoughts summary to say to user",
+    "sql": "SQL Query to run",
+}
+
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = False
@@ -55,11 +61,13 @@ PROMPT_NEED_NEED_STREAM_OUT = False
 chat_db_prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDb.value,
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
-    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
+    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
-    output_parser=DbChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
+    output_parser=DbChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
 )

 CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})
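With the switch to RESPONSE_FORMAT_SIMPLE, the model is asked for a flat two-key object instead of the nested thoughts structure. A compliant reply would look like this (values invented):

    import json

    reply = {
        "thoughts": "List the five newest orders",
        "sql": "SELECT * FROM orders ORDER BY created_at DESC LIMIT 5",
    }
    print(json.dumps(reply, indent=4))  # what DbChatOutputParser expects to receive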
@@ -1,26 +1,156 @@
+import requests
+import datetime
+from urllib.parse import urljoin
 from typing import List
+import traceback

 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.message import OnceConversation
 from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+from pilot.commands.command import execute_command
+from pilot.prompts.generator import PluginPromptGenerator
+
+CFG = Config()

 class ChatWithPlugin(BaseChat):
-    chat_scene: str= ChatScene.ChatExecution.value
-    def __init__(self, chat_mode, chat_session_id, current_user_input):
-        super().__init__(chat_mode, chat_session_id, current_user_input)
-
-    def call(self):
-        super().call()
+    chat_scene: str = ChatScene.ChatExecution.value
+    plugins_prompt_generator: PluginPromptGenerator
+    select_plugin: str = None
+
+    def __init__(self, chat_mode, chat_session_id, current_user_input, select_plugin: str = None):
+        super().__init__(chat_mode, chat_session_id, current_user_input)
+        self.plugins_prompt_generator = PluginPromptGenerator()
+        self.plugins_prompt_generator.command_registry = self.command_registry
+        # load the commands available from plugins
+        self.select_plugin = select_plugin
+        if self.select_plugin:
+            for plugin in CFG.plugins:
+                if plugin.  # (condition truncated in the source diff; presumably matches the selected plugin)
+        else:
+            for plugin in CFG.plugins:
+                if not plugin.can_handle_post_prompt():
+                    continue
+                self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
+
+    def generate_input_values(self):
+        input_values = {
+            "input": self.current_user_input,
+            "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints),
+            "commands_infos": self.plugins_prompt_generator.generate_commands_string()
+        }
+        return input_values
+
+    def do_with_prompt_response(self, prompt_response):
+        ## plugin command run
+        return execute_command(str(prompt_response), self.plugins_prompt_generator)
+
+    # def call(self):
+    #     ...  (a commented-out copy of the generic call() flow is retained on
+    #     the new side of the diff; its only substantive difference is the
+    #     execution step:)
+    #     result = execute_command(prompt_define_response, self.plugins_prompt_generator)

     def chat_show(self):
         super().chat_show()

-    def _load_history(self, session_id: str) -> List[OnceConversation]:
-        return super()._load_history(session_id)
+    def __list_to_prompt_str(self, list: List) -> str:
+        # join the constraint strings, one per line (empty string when none)
+        if list:
+            separator = "\n"
+            return separator.join(list)
+        else:
+            return ""

     def generate(self, p) -> str:
         return super().generate(p)

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatExecution.value
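A sketch of how the plugin chat might be driven (constructor arguments per the diff; the session id and input are invented, and CFG.plugins plus the command registry are assumed to be populated at startup):

    from pilot.scene.base import ChatScene
    from pilot.scene.chat_execution.chat import ChatWithPlugin

    chat = ChatWithPlugin(
        chat_mode=ChatScene.ChatExecution,
        chat_session_id="demo-session",
        current_user_input="search the web for DB-GPT",
    )
    chat.call()                        # prompt -> model -> execute_command
    print(chat.current_ai_response())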
pilot/scene/chat_execution/out_parser.py (new file)
@@ -0,0 +1,30 @@
+import json
+import re
+from abc import ABC, abstractmethod
+from typing import Dict, NamedTuple
+import pandas as pd
+from pilot.utils import build_logger
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.model_config import LOGDIR
+
+logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+
+class PluginAction(NamedTuple):
+    command: Dict
+    thoughts: Dict
+
+class PluginChatOutputParser(BaseOutputParser):
+    def parse_prompt_response(self, model_out_text) -> T:
+        response = json.loads(super().parse_prompt_response(model_out_text))
+        command, thoughts = response["command"], response["thoughts"]
+        return PluginAction(command, thoughts)
+
+    def parse_view_response(self, ai_text) -> str:
+        return super().parse_view_response(ai_text)
+
+    def get_format_instructions(self) -> str:
+        pass
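Same pattern as the DB parser: assuming the base class returns the raw JSON text, parsing yields a PluginAction (example payload invented):

    parser = PluginChatOutputParser(sep="\n", is_stream_out=False)
    action = parser.parse_prompt_response(
        '{"thoughts": {"speak": "searching"}, "command": {"name": "google", "args": {"query": "DB-GPT"}}}'
    )
    print(action.command["name"])  # google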
pilot/scene/chat_execution/prompt.py (new file)
@@ -0,0 +1,65 @@
+import json
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications."""
+
+PROMPT_SUFFIX = """
+Goals:
+    {input}
+
+"""
+
+_DEFAULT_TEMPLATE = """
+Constraints:
+    Exclusively use the commands listed in double quotes e.g. "command name"
+    Reflect on past decisions and strategies to refine your approach.
+    Constructively self-criticize your big-picture behavior constantly.
+    {constraints}
+
+Commands:
+    {commands_infos}
+"""
+
+PROMPT_RESPONSE = """You must respond in JSON format as following format:
+{response}
+
+Ensure the response is correct json and can be parsed by Python json.loads
+"""
+
+RESPONSE_FORMAT = {
+    "thoughts": {
+        "text": "thought",
+        "reasoning": "reasoning",
+        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
+        "criticism": "constructive self-criticism",
+        "speak": "thoughts summary to say to user",
+    },
+    "command": {"name": "command name", "args": {"arg name": "value"}},
+}
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+### Whether the model service is streaming output
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+chat_plugin_prompt = PromptTemplate(
+    template_scene=ChatScene.ChatExecution.value,
+    input_variables=["input", "table_info", "dialect", "top_k", "response"],
+    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
+    template_define=PROMPT_SCENE_DEFINE,
+    template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=PluginChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+)
+
+CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})
pilot/scene/chat_execution/prompt_with_command.py (new file)
@@ -0,0 +1,65 @@
+# (contents identical, line for line, to pilot/scene/chat_execution/prompt.py above)
@@ -1,19 +1,17 @@
 from pilot.scene.base_chat import BaseChat
 from pilot.singleton import Singleton
 from pilot.scene.chat_db.chat import ChatWithDb
 from pilot.scene.chat_execution.chat import ChatWithPlugin

 class ChatFactory(metaclass=Singleton):
     @staticmethod
     def get_implementation(chat_mode, **kwargs):
         chat_classes = BaseChat.__subclasses__()
         implementation = None
         for cls in chat_classes:
-            if(cls.chat_scene == chat_mode):
+            if cls.chat_scene == chat_mode:
                 implementation = cls(**kwargs)
-        if(implementation == None):
-            raise Exception('Invalid implementation name:' + chat_mode)
+        if implementation == None:
+            raise Exception("Invalid implementation name:" + chat_mode)
         return implementation
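The factory resolves a scene string to whichever BaseChat subclass declares that chat_scene; the kwargs must match the target class's __init__. A usage sketch (import path assumed, argument values invented):

    from pilot.scene.chat_factory import ChatFactory  # module path assumed

    chat = ChatFactory.get_implementation(
        "chat_with_db",
        chat_session_id="demo",
        db_name="mydb",
        user_input="how many users signed up today?",
    )
    chat.call()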
@@ -9,12 +9,20 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, ViewMessage, messages_to_dict, messages_from_dict
|
from pilot.scene.base_message import (
|
||||||
|
BaseMessage,
|
||||||
|
AIMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
ViewMessage,
|
||||||
|
messages_to_dict,
|
||||||
|
messages_from_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OnceConversation:
|
class OnceConversation:
|
||||||
"""
|
"""
|
||||||
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
|
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -26,7 +34,9 @@ class OnceConversation:
|
|||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_user_message(self, message: str) -> None:
|
||||||
"""Add a user message to the store"""
|
"""Add a user message to the store"""
|
||||||
has_message = any(isinstance(instance, HumanMessage) for instance in self.messages)
|
has_message = any(
|
||||||
|
isinstance(instance, HumanMessage) for instance in self.messages
|
||||||
|
)
|
||||||
if has_message:
|
if has_message:
|
||||||
raise ValueError("Already Have Human message")
|
raise ValueError("Already Have Human message")
|
||||||
self.messages.append(HumanMessage(content=message))
|
self.messages.append(HumanMessage(content=message))
|
||||||
@@ -38,6 +48,7 @@ class OnceConversation:
             raise ValueError("Already Have Ai message")
         self.messages.append(AIMessage(content=message))

     """ """

     def add_view_message(self, message: str) -> None:
         """Add an AI message to the store"""

@@ -50,7 +61,7 @@ class OnceConversation:

     def set_start_time(self, datatime: datetime):
         dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
-        self.start_date = dt_str;
+        self.start_date = dt_str

     def clear(self) -> None:
         """Remove all messages from the store"""
@@ -71,7 +82,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
         "start_date": start_str,
         "cost": once.cost if once.cost else 0,
         "tokens": once.tokens if once.tokens else 0,
-        "messages": messages_to_dict(once.messages)
+        "messages": messages_to_dict(once.messages),
     }


@@ -81,10 +92,10 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:

 def conversation_from_dict(once: dict) -> OnceConversation:
     conversation = OnceConversation()
-    conversation.cost = once.get('cost', 0)
-    conversation.tokens = once.get('tokens', 0)
-    conversation.start_date = once.get('start_date', '')
-    conversation.chat_order = int(once.get('chat_order'))
-    print(once.get('messages'))
-    conversation.messages = messages_from_dict(once.get('messages', []))
+    conversation.cost = once.get("cost", 0)
+    conversation.tokens = once.get("tokens", 0)
+    conversation.start_date = once.get("start_date", "")
+    conversation.chat_order = int(once.get("chat_order"))
+    print(once.get("messages"))
+    conversation.messages = messages_from_dict(once.get("messages", []))
     return conversation
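_conversation_to_dic and conversation_from_dict form a serialization pair, which is what lets a conversation be cached or pushed to a database as plain JSON. A round-trip sketch, assuming the helpers above are importable and that the dict carries the chat_order field emitted elsewhere in this module:

# Serialize to a plain dict, then rebuild the conversation from it.
conv = OnceConversation()
conv.add_user_message("How many users signed up today?")
conv.add_ai_message("SELECT COUNT(*) FROM users WHERE signup_date = CURDATE();")

payload = _conversation_to_dic(conv)        # JSON-serializable dict
restored = conversation_from_dict(payload)  # rebuilt OnceConversation
print([m.content for m in restored.messages])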
@@ -30,7 +30,7 @@ from pilot.configs.model_config (
     LOGDIR,
     VECTOR_SEARCH_TOP_K,
 )
-from pilot.connections.mysql import MySQLOperator
 from pilot.conversation import (
     SeparatorStyle,
     conv_qa_prompt_template,
@@ -39,9 +39,9 @@ from pilot.conversation import (
     conversation_types,
     default_conversation,
 )
-from pilot.plugins import scan_plugins
+from pilot.common.plugins import scan_plugins
-from pilot.prompts.auto_mode_prompt import AutoModePrompt
-from pilot.prompts.generator import PromptGenerator
+from pilot.prompts.generator import PluginPromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
@@ -95,19 +95,14 @@ def get_simlar(q):


 def gen_sqlgen_conversation(dbname):
-    mo = MySQLOperator(**DB_SETTINGS)
-
     message = ""
-    schemas = mo.get_schema(dbname)
+    db_connect = CFG.local_db.get_session(dbname)
+    schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
         message += s["schema_info"] + ";"
     return f"数据库{dbname}的Schema信息如下: {message}\n"


-def get_database_list():
-    mo = MySQLOperator(**DB_SETTINGS)
-    return mo.get_db_list()
-
-
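The rewritten gen_sqlgen_conversation drops the hard-wired MySQLOperator in favor of the configured CFG.local_db connection. A minimal sketch of the new path, assuming local_db exposes get_session and table_simple_info as used above and that each returned row carries a "schema_info" field:

# Build the schema portion of the prompt from the configured connection.
from pilot.configs.config import Config

CFG = Config()

def schema_prompt(dbname: str) -> str:
    session = CFG.local_db.get_session(dbname)
    tables = CFG.local_db.table_simple_info(session)
    return ";".join(t["schema_info"] for t in tables)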
 get_window_url_params = """
@@ -127,7 +122,6 @@ function() {
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

-    # dbs = get_database_list()
     dropdown_update = gr.Dropdown.update(visible=True)
     if dbs:
         gr.Dropdown.update(choices=dbs)
@@ -137,13 +131,15 @@ def load_demo(url_params, request: gr.Request):
     unique_id = uuid.uuid1()
     state.conv_id = str(unique_id)

-    return (state,
-            dropdown_update,
-            gr.Chatbot.update(visible=True),
-            gr.Textbox.update(visible=True),
-            gr.Button.update(visible=True),
-            gr.Row.update(visible=True),
-            gr.Accordion.update(visible=True))
+    return (
+        state,
+        dropdown_update,
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=True),
+    )


 def get_conv_log_filename():
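Worth flagging in the load_demo hunk above: the result of gr.Dropdown.update(choices=dbs) is never assigned, so the choices update is silently discarded, and dbs has lost its producer now that get_database_list is deleted. A corrected sketch, assuming dbs is a list of database names supplied elsewhere:

# Keep the update object so the dropdown actually receives its choices.
dropdown_update = gr.Dropdown.update(visible=True)
if dbs:  # dbs: assumed list of database names from the caller
    dropdown_update = gr.Dropdown.update(choices=dbs, visible=True)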
@@ -203,30 +199,31 @@ def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
     elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
         return ChatScene.ChatExecution
     else:
         return ChatScene.ChatNormal


-def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(
+    state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+):
     logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
     start_tstamp = time.time()
-    scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
+    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
     print(f"当前对话模式:{scene.value}")
     model_name = CFG.LLM_MODEL

     if ChatScene.ChatWithDb == scene:
         logger.info("基于DB对话走新的模式!")
-        chat_param ={
+        chat_param = {
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
-            "user_input": state.last_user_input
+            "user_input": state.last_user_input,
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
         chat.call()
         state.messages[-1][-1] = f"{chat.current_ai_response()}"
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

     else:
         dbname = db_selector
         # TODO: this request needs to be combined with the existing knowledge base so the model answers from it; the prompt still needs refinement
         if state.skip_next:
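The ChatWithDb branch above is the new scene dispatch: get_chat_mode maps UI state to a ChatScene, and CHAT_FACTORY.get_implementation builds the BaseChat registered for that scene from keyword parameters. A hedged sketch mirroring the diff, with illustrative values; the factory registry itself lives elsewhere in the codebase:

# Resolve the ChatWithDb implementation and run one round.
chat_param = {
    "chat_session_id": "demo-session",                  # illustrative
    "db_name": "orders_db",                             # illustrative
    "user_input": "How many orders were placed today?",
}
chat = CHAT_FACTORY.get_implementation(ChatScene.ChatWithDb.value, **chat_param)
chat.call()  # one round; the reply is kept on the chat object
print(chat.current_ai_response())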
@@ -242,7 +239,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     # Add a context hint to the prompt so the model answers from existing knowledge. Should the hint go only in the first round, or be added every round?
     # If the user's questions range widely across topics, add the hint every round.
     if db_selector:
-        new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+        new_state.append_message(
+            new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+        )
         new_state.append_message(new_state.roles[1], None)
     else:
         new_state.append_message(new_state.roles[0], query)
@@ -251,7 +250,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     new_state.conv_id = uuid.uuid4().hex
     state = new_state

     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:
@@ -263,16 +261,24 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1

     if mode == conversation_types["custome"] and not db_selector:
-        persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
+        persist_dir = os.path.join(
+            KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
+        )
         print("向量数据库持久化地址: ", persist_dir)
-        knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"],
-                                                        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["sentence-transforms"],
+            vector_store_config={
+                "vector_store_name": vector_store_name["vs_name"],
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            },
+        )
         query = state.messages[-2][1]
         docs = knowledge_embedding_client.similar_search(query, 1)
         context = [d.page_content for d in docs]
         prompt_template = PromptTemplate(
             template=conv_qa_prompt_template,
-            input_variables=["context", "question"]
+            input_variables=["context", "question"],
         )
         result = prompt_template.format(context="\n".join(context), question=query)
         state.messages[-2][1] = result
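The custome branch above is a small retrieve-then-format pipeline: fetch the nearest document, join its text into a context block, and rewrite the user turn through a PromptTemplate. A standalone sketch using langchain's PromptTemplate; the template text and document are stand-ins for conv_qa_prompt_template and the similar_search results:

from langchain.prompts import PromptTemplate

qa_template = (
    "Answer using only the context below.\n"
    "Context:\n{context}\n\n"
    "Question: {question}"
)
docs = ["orders(id INT, user_id INT, amount DECIMAL(10, 2))"]  # stand-in page_content

prompt_template = PromptTemplate(
    template=qa_template,
    input_variables=["context", "question"],
)
result = prompt_template.format(
    context="\n".join(docs), question="Which table stores order amounts?"
)
print(result)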
@@ -285,7 +291,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         "prompt": prompt,
         "temperature": float(temperature),
         "max_new_tokens": int(max_new_tokens),
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
+        "stop": state.sep
+        if state.sep_style == SeparatorStyle.SINGLE
+        else state.sep2,
     }
     logger.info(f"Requert: \n{payload}")

@@ -295,8 +303,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re

     try:
         # Stream output
-        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                                 headers=headers, json=payload, stream=True, timeout=20)
+        response = requests.post(
+            urljoin(CFG.MODEL_SERVER, "generate_stream"),
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=20,
+        )
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
@@ -309,12 +322,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
                     output = data["text"] + f" (error_code: {data['error_code']})"
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (
-                        disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+                        disable_btn,
+                        disable_btn,
+                        disable_btn,
+                        enable_btn,
+                        enable_btn,
+                    )
                     return

     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        yield (state, state.to_gradio_chatbot()) + (
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
         return

     state.messages[-1][-1] = state.messages[-1][-1][:-1]
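The streaming loop above relies on the model server emitting JSON blobs separated by NUL bytes on its generate_stream route. A standalone sketch of a client for that protocol; the server URL and payload fields mirror the diff, the default address is an assumption, and error handling is trimmed:

import json
from urllib.parse import urljoin

import requests

MODEL_SERVER = "http://localhost:8000"  # assumed default address

payload = {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 64, "stop": "###"}
response = requests.post(
    urljoin(MODEL_SERVER, "generate_stream"),
    json=payload,
    stream=True,
    timeout=20,
)
# The server is assumed to delimit JSON chunks with b"\0".
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data.get("error_code", 0) == 0:
            print(data["text"])  # partial completion so far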
@@ -405,10 +429,14 @@ def build_single_model_ui():
         interactive=True,
         label="最大输出Token数",
     )

     tabs = gr.Tabs()

     with tabs:
         tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
         with tab_sql:
+            print("tab_sql in...")
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
                 db_selector = gr.Dropdown(
@@ -423,8 +451,23 @@ def build_single_model_ui():
             sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
             sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

+        tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
+        with tab_plugin:
+            print("tab_plugin in...")
+            with gr.Row(elem_id="plugin_selector"):
+                # TODO
+                plugin_selector = gr.Dropdown(
+                    label="请选择插件",
+                    choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""],
+                    value="datadance-ddl-excutor",
+                    interactive=True,
+                    show_label=True,
+                ).style(container=False)

         tab_qa = gr.TabItem("知识问答", elem_id="QA")
         with tab_qa:
+            print("tab_qa in...")
             mode = gr.Radio(
                 ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
             )
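The new plugin tab seeds its dropdown with hard-coded strings. A hedged alternative would derive the choices from scan_plugins, which this commit now imports from pilot.common.plugins; the _name/_description attributes and the debug_mode flag below are assumptions based on the Auto-GPT-style plugin template, not confirmed by this diff:

# Sketch: populate the plugin dropdown from the scanned plugin registry.
plugins = scan_plugins(CFG, CFG.debug_mode)
plugin_selector = gr.Dropdown(
    label="请选择插件",
    choices=[f"[{p._name}] {p._description}" for p in plugins],
    interactive=True,
    show_label=True,
)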
@@ -483,7 +526,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
         [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )

@@ -573,7 +616,6 @@ def knowledge_embedding_store(vs_id, files):
     )
     knowledge_embedding_client.knowledge_embedding()

-
     logger.info("knowledge embedding success")
     return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")

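For readers new to gradio's event chaining: the submit handler above runs add_text first, then pipes the updated state into http_bot, which streams tokens back into the same chatbot. A minimal self-contained illustration of the pattern; component names and handlers here are illustrative, not the app's:

import gradio as gr

def add_text(history, text):
    # Record the user turn with an empty bot slot.
    return history + [[text, None]]

def bot(history):
    # Fill the bot slot; the real app streams instead of echoing.
    history[-1][1] = "echo: " + history[-1][0]
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    textbox.submit(add_text, [chatbot, textbox], chatbot).then(bot, chatbot, chatbot)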
@@ -1,7 +1,7 @@
 from pilot.vector_store.chroma_store import ChromaStore
-from pilot.vector_store.milvus_store import MilvusStore
+# from pilot.vector_store.milvus_store import MilvusStore

-connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}
+connector = {"Chroma": ChromaStore, "Milvus": None}


 class VectorStoreConnector:
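Commenting out the Milvus import keeps the registry importable without pymilvus, at the cost of a None entry. A hedged alternative sketch that guards the optional dependency instead; the try/except pattern is a suggestion, not the committed code:

from pilot.vector_store.chroma_store import ChromaStore

try:
    from pilot.vector_store.milvus_store import MilvusStore
except ImportError:  # pymilvus not installed
    MilvusStore = None

connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}

Either way, callers selecting "Milvus" should check for None before instantiating.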
@@ -54,7 +54,7 @@ gTTS==2.3.1
 langchain
 nltk
 python-dotenv==1.0.0
-pymilvus==2.2.1
+# pymilvus==2.2.1
 vcrpy
 chromadb
 markdown2