From 20edf6daaa5444d40a81b9bc7cbc9cef1bc05b23 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Mon, 29 May 2023 19:32:20 +0800 Subject: [PATCH 1/7] add plugin mode --- pilot/agent/json_fix_llm.py | 48 --- pilot/commands/command.py | 16 +- pilot/common/markdown_text.py | 11 +- pilot/{ => common}/plugins.py | 0 pilot/common/schema.py | 3 +- pilot/common/sql_database.py | 44 ++- pilot/configs/ai_config.py | 167 --------- pilot/configs/config.py | 14 +- pilot/connections/base.py | 31 +- pilot/connections/mysql.py | 64 ---- .../nosql/__init__.py} | 0 pilot/connections/rdbms/__init__.py | 0 pilot/connections/{ => rdbms}/clickhouse.py | 0 pilot/connections/{ => rdbms}/es.py | 0 pilot/connections/{ => rdbms}/mongo.py | 0 pilot/connections/rdbms/mysql.py | 18 + pilot/connections/{ => rdbms}/oracle.py | 0 pilot/connections/{ => rdbms}/postgres.py | 0 pilot/connections/rdbms/rdbms_connect.py | 318 ++++++++++++++++++ pilot/conversation.py | 1 + pilot/memory/chat_history/base.py | 10 +- pilot/memory/chat_history/file_history.py | 22 +- pilot/out_parser/base.py | 44 ++- pilot/prompts/auto_mode_prompt.py | 143 -------- pilot/prompts/base.py | 8 +- pilot/prompts/generator.py | 7 +- pilot/prompts/prompt.py | 73 ---- pilot/prompts/prompt_generator.py | 8 +- pilot/prompts/prompt_new.py | 14 +- pilot/prompts/prompt_template.py | 4 +- pilot/scene/base.py | 3 +- pilot/scene/base_chat.py | 180 +++++++++- pilot/scene/base_message.py | 8 +- pilot/scene/chat_db/chat.py | 214 ++++++------ pilot/scene/chat_db/out_parser.py | 40 +-- pilot/scene/chat_db/prompt.py | 14 +- pilot/scene/chat_execution/chat.py | 146 +++++++- pilot/scene/chat_execution/out_parser.py | 30 ++ pilot/scene/chat_execution/prompt.py | 65 ++++ .../chat_execution/prompt_with_command.py | 65 ++++ pilot/scene/chat_factory.py | 12 +- pilot/scene/message.py | 33 +- pilot/server/webserver.py | 122 ++++--- pilot/vector_store/connector.py | 4 +- requirements.txt | 2 +- 45 files changed, 1202 insertions(+), 804 deletions(-) rename pilot/{ => common}/plugins.py (100%) delete mode 100644 pilot/configs/ai_config.py delete mode 100644 pilot/connections/mysql.py rename pilot/{prompts/generator_new.py => connections/nosql/__init__.py} (100%) create mode 100644 pilot/connections/rdbms/__init__.py rename pilot/connections/{ => rdbms}/clickhouse.py (100%) rename pilot/connections/{ => rdbms}/es.py (100%) rename pilot/connections/{ => rdbms}/mongo.py (100%) create mode 100644 pilot/connections/rdbms/mysql.py rename pilot/connections/{ => rdbms}/oracle.py (100%) rename pilot/connections/{ => rdbms}/postgres.py (100%) create mode 100644 pilot/connections/rdbms/rdbms_connect.py delete mode 100644 pilot/prompts/auto_mode_prompt.py delete mode 100644 pilot/prompts/prompt.py create mode 100644 pilot/scene/chat_execution/out_parser.py create mode 100644 pilot/scene/chat_execution/prompt.py create mode 100644 pilot/scene/chat_execution/prompt_with_command.py diff --git a/pilot/agent/json_fix_llm.py b/pilot/agent/json_fix_llm.py index 327881a78..075634784 100644 --- a/pilot/agent/json_fix_llm.py +++ b/pilot/agent/json_fix_llm.py @@ -55,54 +55,6 @@ def fix_and_parse_json( logger.error("参数解析错误", e) -def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]: - """Fix the given JSON string to make it parseable and fully compliant with two techniques. - - Args: - json_string (str): The JSON string to fix. - - Returns: - str: The fixed JSON string. 
- """ - assistant_reply = assistant_reply.strip() - if assistant_reply.startswith("```json"): - assistant_reply = assistant_reply[7:] - if assistant_reply.endswith("```"): - assistant_reply = assistant_reply[:-3] - try: - return json.loads(assistant_reply) # just check the validity - except json.JSONDecodeError as e: # noqa: E722 - print(f"JSONDecodeError: {e}") - pass - - if assistant_reply.startswith("json "): - assistant_reply = assistant_reply[5:] - assistant_reply = assistant_reply.strip() - try: - return json.loads(assistant_reply) # just check the validity - except json.JSONDecodeError: # noqa: E722 - pass - - # Parse and print Assistant response - assistant_reply_json = fix_and_parse_json(assistant_reply) - logger.debug("Assistant reply JSON: %s", str(assistant_reply_json)) - if assistant_reply_json == {}: - assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets( - assistant_reply - ) - - logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json)) - if assistant_reply_json != {}: - return assistant_reply_json - - logger.error( - "Error: The following AI output couldn't be converted to a JSON:\n", - assistant_reply, - ) - if CFG.speak_mode: - say_text("I have received an invalid JSON response from the OpenAI API.") - - return {} def correct_json(json_to_load: str) -> str: diff --git a/pilot/commands/command.py b/pilot/commands/command.py index 0200ef6cd..8838efff1 100644 --- a/pilot/commands/command.py +++ b/pilot/commands/command.py @@ -4,10 +4,9 @@ import json from typing import Dict -from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques from pilot.commands.exception_not_commands import NotCommands from pilot.configs.config import Config -from pilot.prompts.generator import PromptGenerator +from pilot.prompts.generator import PluginPromptGenerator from pilot.speech import say_text @@ -24,8 +23,8 @@ def _resolve_pathlike_command_args(command_args): def execute_ai_response_json( - prompt: PromptGenerator, - ai_response: str, + prompt: PluginPromptGenerator, + ai_response, user_input: str = None, ) -> str: """ @@ -39,11 +38,8 @@ def execute_ai_response_json( """ cfg = Config() - try: - assistant_reply_json = fix_json_using_multiple_techniques(ai_response) - except (json.JSONDecodeError, ValueError, AttributeError) as e: - raise NotCommands("非可执行命令结构") - command_name, arguments = get_command(assistant_reply_json) + + command_name, arguments = get_command(ai_response) if cfg.speak_mode: say_text(f"I want to execute {command_name}") @@ -71,7 +67,7 @@ def execute_ai_response_json( def execute_command( command_name: str, arguments, - prompt: PromptGenerator, + prompt: PluginPromptGenerator, ): """Execute the command and return the result diff --git a/pilot/common/markdown_text.py b/pilot/common/markdown_text.py index 1d90ba645..1244160fd 100644 --- a/pilot/common/markdown_text.py +++ b/pilot/common/markdown_text.py @@ -1,21 +1,21 @@ -import markdown2 +import markdown2 import pandas as pd + def datas_to_table_html(data): df = pd.DataFrame(data[1:], columns=data[0]) table_style = """""" - html_table = df.to_html(index=False, escape=False) + html_table = df.to_html(index=False, escape=False) html = f"
{table_style}{html_table}" return html.replace("\n", " ") - def generate_markdown_table(data): - """\n 生成 Markdown 表格\n data: 一个包含表头和表格内容的二维列表\n """ + """\n 生成 Markdown 表格\n data: 一个包含表头和表格内容的二维列表\n""" # 获取表格列数 num_cols = len(data[0]) # 生成表头 @@ -41,6 +41,7 @@ def generate_markdown_table(data): return table + def generate_htm_table(data): markdown_text = generate_markdown_table(data) html_table = markdown2.markdown(markdown_text, extras=["tables"]) @@ -53,4 +54,4 @@ if __name__ == "__main__": table_style = """""" - print(table_style.replace("\n", " ")) \ No newline at end of file + print(table_style.replace("\n", " ")) diff --git a/pilot/plugins.py b/pilot/common/plugins.py similarity index 100% rename from pilot/plugins.py rename to pilot/common/plugins.py diff --git a/pilot/common/schema.py b/pilot/common/schema.py index f66bba1a6..cd462966c 100644 --- a/pilot/common/schema.py +++ b/pilot/common/schema.py @@ -1,8 +1,9 @@ from enum import auto, Enum from typing import List, Any + class SeparatorStyle(Enum): - SINGLE ="###" + SINGLE = "###" TWO = "" THREE = auto() FOUR = auto() diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py index 2c16869d5..2b8d6fe4b 100644 --- a/pilot/common/sql_database.py +++ b/pilot/common/sql_database.py @@ -30,16 +30,16 @@ class Database: """SQLAlchemy wrapper around a database.""" def __init__( - self, - engine, - schema: Optional[str] = None, - metadata: Optional[MetaData] = None, - ignore_tables: Optional[List[str]] = None, - include_tables: Optional[List[str]] = None, - sample_rows_in_table_info: int = 3, - indexes_in_table_info: bool = False, - custom_table_info: Optional[dict] = None, - view_support: bool = False, + self, + engine, + schema: Optional[str] = None, + metadata: Optional[MetaData] = None, + ignore_tables: Optional[List[str]] = None, + include_tables: Optional[List[str]] = None, + sample_rows_in_table_info: int = 3, + indexes_in_table_info: bool = False, + custom_table_info: Optional[dict] = None, + view_support: bool = False, ): """Create engine from database URI.""" self._engine = engine @@ -119,7 +119,7 @@ class Database: @classmethod def from_uri( - cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any + cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any ) -> Database: """Construct a SQLAlchemy engine from URI.""" _engine_args = engine_args or {} @@ -148,7 +148,7 @@ class Database: self._metadata = MetaData() # sql = f"use {db_name}" - sql = text(f'use `{db_name}`') + sql = text(f"use `{db_name}`") session.execute(sql) # 处理表信息数据 @@ -159,13 +159,17 @@ class Database: # tables list if view_support is True self._all_tables = set( self._inspector.get_table_names(schema=db_name) - + (self._inspector.get_view_names(schema=db_name) if self.view_support else []) + + ( + self._inspector.get_view_names(schema=db_name) + if self.view_support + else [] + ) ) return session def get_current_db_name(self, session) -> str: - return session.execute(text('SELECT DATABASE()')).scalar() + return session.execute(text("SELECT DATABASE()")).scalar() def table_simple_info(self, session): _sql = f""" @@ -201,7 +205,7 @@ class Database: tbl for tbl in self._metadata.sorted_tables if tbl.name in set(all_table_names) - and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) + and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) ] tables = [] @@ -214,7 +218,7 @@ class Database: create_table = str(CreateTable(table).compile(self._engine)) table_info = f"{create_table.rstrip()}" 
has_extra_info = ( - self._indexes_in_table_info or self._sample_rows_in_table_info + self._indexes_in_table_info or self._sample_rows_in_table_info ) if has_extra_info: table_info += "\n\n/*" @@ -303,6 +307,10 @@ class Database: def get_database_list(self): session = self._db_sessions() - cursor = session.execute(text(' show databases;')) + cursor = session.execute(text(" show databases;")) results = cursor.fetchall() - return [d[0] for d in results if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]] + return [ + d[0] + for d in results + if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"] + ] diff --git a/pilot/configs/ai_config.py b/pilot/configs/ai_config.py deleted file mode 100644 index ed9b4e2f8..000000000 --- a/pilot/configs/ai_config.py +++ /dev/null @@ -1,167 +0,0 @@ -# sourcery skip: do-not-use-staticmethod -""" -A module that contains the AIConfig class object that contains the configuration -""" -from __future__ import annotations - -import os -import platform -from pathlib import Path -from typing import Optional - -import distro -import yaml - -from pilot.configs.config import Config -from pilot.prompts.generator import PromptGenerator -from pilot.prompts.prompt import build_default_prompt_generator - -# Soon this will go in a folder where it remembers more stuff about the run(s) -SAVE_FILE = str(Path(os.getcwd()) / "ai_settings.yaml") - - -class AIConfig: - """ - A class object that contains the configuration information for the AI - - Attributes: - ai_name (str): The name of the AI. - ai_role (str): The description of the AI's role. - ai_goals (list): The list of objectives the AI is supposed to complete. - api_budget (float): The maximum dollar value for API calls (0.0 means infinite) - """ - - def __init__( - self, - ai_name: str = "", - ai_role: str = "", - ai_goals: list | None = None, - api_budget: float = 0.0, - ) -> None: - """ - Initialize a class instance - - Parameters: - ai_name (str): The name of the AI. - ai_role (str): The description of the AI's role. - ai_goals (list): The list of objectives the AI is supposed to complete. - api_budget (float): The maximum dollar value for API calls (0.0 means infinite) - Returns: - None - """ - if ai_goals is None: - ai_goals = [] - self.ai_name = ai_name - self.ai_role = ai_role - self.ai_goals = ai_goals - self.api_budget = api_budget - self.prompt_generator = None - self.command_registry = None - - @staticmethod - def load(config_file: str = SAVE_FILE) -> "AIConfig": - """ - Returns class object with parameters (ai_name, ai_role, ai_goals, api_budget) loaded from - yaml file if yaml file exists, - else returns class with no parameters. - - Parameters: - config_file (int): The path to the config yaml file. 
- DEFAULT: "../ai_settings.yaml" - - Returns: - cls (object): An instance of given cls object - """ - - try: - with open(config_file, encoding="utf-8") as file: - config_params = yaml.load(file, Loader=yaml.FullLoader) - except FileNotFoundError: - config_params = {} - - ai_name = config_params.get("ai_name", "") - ai_role = config_params.get("ai_role", "") - ai_goals = [ - str(goal).strip("{}").replace("'", "").replace('"', "") - if isinstance(goal, dict) - else str(goal) - for goal in config_params.get("ai_goals", []) - ] - api_budget = config_params.get("api_budget", 0.0) - # type is Type[AIConfig] - return AIConfig(ai_name, ai_role, ai_goals, api_budget) - - def save(self, config_file: str = SAVE_FILE) -> None: - """ - Saves the class parameters to the specified file yaml file path as a yaml file. - - Parameters: - config_file(str): The path to the config yaml file. - DEFAULT: "../ai_settings.yaml" - - Returns: - None - """ - - config = { - "ai_name": self.ai_name, - "ai_role": self.ai_role, - "ai_goals": self.ai_goals, - "api_budget": self.api_budget, - } - with open(config_file, "w", encoding="utf-8") as file: - yaml.dump(config, file, allow_unicode=True) - - def construct_full_prompt( - self, prompt_generator: Optional[PromptGenerator] = None - ) -> str: - """ - Returns a prompt to the user with the class information in an organized fashion. - - Parameters: - None - - Returns: - full_prompt (str): A string containing the initial prompt for the user - including the ai_name, ai_role, ai_goals, and api_budget. - """ - - prompt_start = ( - "Your decisions must always be made independently without" - " seeking user assistance. Play to your strengths as an LLM and pursue" - " simple strategies with no legal complications." - "" - ) - - cfg = Config() - if prompt_generator is None: - prompt_generator = build_default_prompt_generator() - prompt_generator.goals = self.ai_goals - prompt_generator.name = self.ai_name - prompt_generator.role = self.ai_role - prompt_generator.command_registry = self.command_registry - for plugin in cfg.plugins: - if not plugin.can_handle_post_prompt(): - continue - prompt_generator = plugin.post_prompt(prompt_generator) - - if cfg.execute_local_commands: - # add OS info to prompt - os_name = platform.system() - os_info = ( - platform.platform(terse=True) - if os_name != "Linux" - else distro.name(pretty=True) - ) - - prompt_start += f"\nThe OS you are running on is: {os_info}" - - # Construct full prompt - full_prompt = f"You are {prompt_generator.name}, {prompt_generator.role}\n{prompt_start}\n\nGOALS:\n\n" - for i, goal in enumerate(self.ai_goals): - full_prompt += f"{i+1}. {goal}\n" - if self.api_budget > 0.0: - full_prompt += f"\nIt takes money to let you run. 
Your API budget is ${self.api_budget:.3f}"
-        self.prompt_generator = prompt_generator
-        full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}"
-        return full_prompt
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 88cbf5117..518275f34 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -47,7 +47,6 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
         self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
 
-
         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
         self.image_provider = os.getenv("IMAGE_PROVIDER", True)
@@ -104,8 +103,17 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
 
         ### TODO Adapt to multiple types of libraries
-        self.local_db = Database.from_uri("mysql+pymysql://" + self.LOCAL_DB_USER +":"+ self.LOCAL_DB_PASSWORD +"@" +self.LOCAL_DB_HOST + ":" + str(self.LOCAL_DB_PORT) ,
-                                          engine_args ={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        self.local_db = Database.from_uri(
+            "mysql+pymysql://"
+            + self.LOCAL_DB_USER
+            + ":"
+            + self.LOCAL_DB_PASSWORD
+            + "@"
+            + self.LOCAL_DB_HOST
+            + ":"
+            + str(self.LOCAL_DB_PORT),
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )
 
         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
diff --git a/pilot/connections/base.py b/pilot/connections/base.py
index ec41f9273..3905f410d 100644
--- a/pilot/connections/base.py
+++ b/pilot/connections/base.py
@@ -2,7 +2,29 @@
 # -*- coding:utf-8 -*-
 
 """We need to design a base class. That other connector can Write with this"""
+from abc import ABC, abstractmethod
+from pydantic import BaseModel, Extra, Field, root_validator
+from typing import Any, Iterable, List, Optional
 
 
-class BaseConnection:
-    pass
+class BaseConnect(BaseModel, ABC):
+    type: str
+    driver: str
+
+    def get_session(self, db_name: str):
+        pass
+
+    def get_table_names(self) -> Iterable[str]:
+        pass
+
+    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+        pass
+
+    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
+        pass
+
+    def get_database_list(self):
+        pass
+
+    def run(self, session, command: str, fetch: str = "all") -> List:
+        pass
diff --git a/pilot/connections/mysql.py b/pilot/connections/mysql.py
deleted file mode 100644
index 2f5a1e152..000000000
--- a/pilot/connections/mysql.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import pymysql
-
-class MySQLOperator:
-    """Connect MySQL Database fetch MetaData For LLM Prompt
-    Args:
-
-    Usage:
-    """
-
-    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
-
-    def __init__(self, user, password, host="localhost", port=3306) -> None:
-        self.conn = pymysql.connect(
-            host=host,
-            user=user,
-            port=port,
-            passwd=password,
-            charset="utf8mb4",
-            cursorclass=pymysql.cursors.DictCursor,
-        )
-
-    def get_schema(self, schema_name):
-        with self.conn.cursor() as cursor:
-            _sql = f"""
-                select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{schema_name}" group by TABLE_NAME;
-            """
-            cursor.execute(_sql)
-            results = cursor.fetchall()
-            return results
-
-
-    def run_sql(self, db_name:str, sql:str, fetch: str = "all"):
-        with self.conn.cursor() as 
cursor:
-            cursor.execute("USE " + db_name)
-            cursor.execute(sql)
-            if fetch == "all":
-                result = cursor.fetchall()
-            elif fetch == "one":
-                result = cursor.fetchone()[0]  # type: ignore
-            else:
-                raise ValueError("Fetch parameter must be either 'one' or 'all'")
-            return str(result)
-
-    def get_index(self, schema_name):
-        pass
-
-    def get_db_list(self):
-        with self.conn.cursor() as cursor:
-            _sql = """
-            show databases;
-            """
-            cursor.execute(_sql)
-            results = cursor.fetchall()
-
-            dbs = [
-                d["Database"] for d in results if d["Database"] not in self.default_db
-            ]
-            return dbs
-
-    def get_meta(self, schema_name):
-        pass
diff --git a/pilot/prompts/generator_new.py b/pilot/connections/nosql/__init__.py
similarity index 100%
rename from pilot/prompts/generator_new.py
rename to pilot/connections/nosql/__init__.py
diff --git a/pilot/connections/rdbms/__init__.py b/pilot/connections/rdbms/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/connections/clickhouse.py b/pilot/connections/rdbms/clickhouse.py
similarity index 100%
rename from pilot/connections/clickhouse.py
rename to pilot/connections/rdbms/clickhouse.py
diff --git a/pilot/connections/es.py b/pilot/connections/rdbms/es.py
similarity index 100%
rename from pilot/connections/es.py
rename to pilot/connections/rdbms/es.py
diff --git a/pilot/connections/mongo.py b/pilot/connections/rdbms/mongo.py
similarity index 100%
rename from pilot/connections/mongo.py
rename to pilot/connections/rdbms/mongo.py
diff --git a/pilot/connections/rdbms/mysql.py b/pilot/connections/rdbms/mysql.py
new file mode 100644
index 000000000..9d99f3e9b
--- /dev/null
+++ b/pilot/connections/rdbms/mysql.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import pymysql
+from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
+
+
+class MySQLConnect(RDBMSDatabase):
+    """Connect to a MySQL database and fetch metadata for the LLM prompt."""
+
+    type: str = "MySQL"
+    connect_url = "mysql+pymysql://"
+
+    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
diff --git a/pilot/connections/oracle.py b/pilot/connections/rdbms/oracle.py
similarity index 100%
rename from pilot/connections/oracle.py
rename to pilot/connections/rdbms/oracle.py
diff --git a/pilot/connections/postgres.py b/pilot/connections/rdbms/postgres.py
similarity index 100%
rename from pilot/connections/postgres.py
rename to pilot/connections/rdbms/postgres.py
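MySQLConnect above shows the intended connector pattern: a concrete dialect only declares its type, its SQLAlchemy driver prefix, and the system databases to hide, and inherits all behaviour from the RDBMSDatabase wrapper added below. A sketch of how another dialect could plug in (illustrative only; this class, the psycopg2 driver prefix, and the system-database list are assumptions, not part of this patch):

    # Hypothetical sibling of MySQLConnect, not included in this patch.
    from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase


    class PostgreSQLConnect(RDBMSDatabase):
        """Connect to a PostgreSQL database and fetch metadata for the LLM prompt."""

        type: str = "PostgreSQL"
        connect_url = "postgresql+psycopg2://"  # assumed driver prefix

        default_db = ["template0", "template1", "postgres"]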
diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/rdbms_connect.py
new file mode 100644
index 000000000..d3cca616f
--- /dev/null
+++ b/pilot/connections/rdbms/rdbms_connect.py
@@ -0,0 +1,316 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any, Iterable, List, Optional
+from pydantic import BaseModel, Field, root_validator, validator, Extra
+from abc import ABC, abstractmethod
+import sqlalchemy
+from sqlalchemy import (
+    MetaData,
+    Table,
+    create_engine,
+    inspect,
+    select,
+    text,
+)
+from sqlalchemy.engine import CursorResult, Engine
+from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
+from sqlalchemy.schema import CreateTable
+from sqlalchemy.orm import sessionmaker, scoped_session
+
+from pilot.connections.base import BaseConnect
+
+
+def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
+    return (
+        f'Name: {index["name"]}, Unique: {index["unique"]},'
+        f' Columns: {str(index["column_names"])}'
+    )
+
+
+class RDBMSDatabase(BaseConnect):
+    """SQLAlchemy wrapper around a database."""
+
+    def __init__(
+        self,
+        engine,
+        schema: Optional[str] = None,
+        metadata: Optional[MetaData] = None,
+        ignore_tables: Optional[List[str]] = None,
+        include_tables: Optional[List[str]] = None,
+        sample_rows_in_table_info: int = 3,
+        indexes_in_table_info: bool = False,
+        custom_table_info: Optional[dict] = None,
+        view_support: bool = False,
+    ):
+        """Create engine from database URI."""
+        self._engine = engine
+        self._schema = schema
+        if include_tables and ignore_tables:
+            raise ValueError("Cannot specify both include_tables and ignore_tables")
+
+        self._inspector = inspect(self._engine)
+        session_factory = sessionmaker(bind=engine)
+        Session = scoped_session(session_factory)
+
+        self._db_sessions = Session
+
+        self._all_tables = set()
+        self.view_support = view_support
+        self._usable_tables = set()
+        self._include_tables = set()
+        self._ignore_tables = set()
+        self._custom_table_info = custom_table_info or {}
+        self._indexes_in_table_info = indexes_in_table_info
+        self._sample_rows_in_table_info = sample_rows_in_table_info
+        # including view support by adding the views as well as tables to the all
+        # tables list if view_support is True
+        # self._all_tables = set(
+        #     self._inspector.get_table_names(schema=schema)
+        #     + (self._inspector.get_view_names(schema=schema) if view_support else [])
+        # )
+
+        # self._include_tables = set(include_tables) if include_tables else set()
+        # if self._include_tables:
+        #     missing_tables = self._include_tables - self._all_tables
+        #     if missing_tables:
+        #         raise ValueError(
+        #             f"include_tables {missing_tables} not found in database"
+        #         )
+        # self._ignore_tables = set(ignore_tables) if ignore_tables else set()
+        # if self._ignore_tables:
+        #     missing_tables = self._ignore_tables - self._all_tables
+        #     if missing_tables:
+        #         raise ValueError(
+        #             f"ignore_tables {missing_tables} not found in database"
+        #         )
+        # usable_tables = self.get_usable_table_names()
+        # self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
+
+        # if not isinstance(sample_rows_in_table_info, int):
+        #     raise TypeError("sample_rows_in_table_info must be an integer")
+        #
+        # self._sample_rows_in_table_info = sample_rows_in_table_info
+        # self._indexes_in_table_info = indexes_in_table_info
+        #
+        # self._custom_table_info = custom_table_info
+        # if self._custom_table_info:
+        #     if not isinstance(self._custom_table_info, dict):
+        #         raise TypeError(
+        #             "table_info must be a dictionary with table names as keys and the "
+        #             "desired table info as values"
+        #         )
+        #     # only keep the tables that are also present in the database
+        #     intersection = set(self._custom_table_info).intersection(self._all_tables)
+        #     self._custom_table_info = dict(
+        #         (table, self._custom_table_info[table])
+        #         for table in self._custom_table_info
+        #         if table in intersection
+        #     )
+
+        # self._metadata = metadata or MetaData()
+        # # # including view support if view_support = true
+        # self._metadata.reflect(
+        #     views=view_support,
+        #     bind=self._engine,
+        #     only=list(self._usable_tables),
+        #     schema=self._schema,
+        # )
+
+    @classmethod
+    def from_uri(
+        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> RDBMSDatabase:
+        """Construct a SQLAlchemy engine from URI."""
+        _engine_args = engine_args or {}
+        return cls(create_engine(database_uri, **_engine_args), **kwargs)
+
+    @property
+    def dialect(self) -> str:
+        """Return string representation of dialect to use."""
+        return self._engine.dialect.name
+
+    def get_usable_table_names(self) -> 
Iterable[str]: + """Get names of tables available.""" + if self._include_tables: + return self._include_tables + return self._all_tables - self._ignore_tables + + def get_table_names(self) -> Iterable[str]: + """Get names of tables available.""" + warnings.warn( + "This method is deprecated - please use `get_usable_table_names`." + ) + return self.get_usable_table_names() + + def get_session(self, db_name: str): + session = self._db_sessions() + + self._metadata = MetaData() + # sql = f"use {db_name}" + sql = text(f"use `{db_name}`") + session.execute(sql) + + # 处理表信息数据 + + self._metadata.reflect(bind=self._engine, schema=db_name) + + # including view support by adding the views as well as tables to the all + # tables list if view_support is True + self._all_tables = set( + self._inspector.get_table_names(schema=db_name) + + ( + self._inspector.get_view_names(schema=db_name) + if self.view_support + else [] + ) + ) + + return session + + def get_current_db_name(self, session) -> str: + return session.execute(text("SELECT DATABASE()")).scalar() + + def table_simple_info(self, session): + _sql = f""" + select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME; + """ + cursor = session.execute(text(_sql)) + results = cursor.fetchall() + return results + + @property + def table_info(self) -> str: + """Information about all tables in the database.""" + return self.get_table_info() + + def get_table_info(self, table_names: Optional[List[str]] = None) -> str: + """Get information about specified tables. + + Follows best practices as specified in: Rajkumar et al, 2022 + (https://arxiv.org/abs/2204.00498) + + If `sample_rows_in_table_info`, the specified number of sample rows will be + appended to each table description. This can increase performance as + demonstrated in the paper. 
+ """ + all_table_names = self.get_usable_table_names() + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + raise ValueError(f"table_names {missing_tables} not found in database") + all_table_names = table_names + + meta_tables = [ + tbl + for tbl in self._metadata.sorted_tables + if tbl.name in set(all_table_names) + and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) + ] + + tables = [] + for table in meta_tables: + if self._custom_table_info and table.name in self._custom_table_info: + tables.append(self._custom_table_info[table.name]) + continue + + # add create table command + create_table = str(CreateTable(table).compile(self._engine)) + table_info = f"{create_table.rstrip()}" + has_extra_info = ( + self._indexes_in_table_info or self._sample_rows_in_table_info + ) + if has_extra_info: + table_info += "\n\n/*" + if self._indexes_in_table_info: + table_info += f"\n{self._get_table_indexes(table)}\n" + if self._sample_rows_in_table_info: + table_info += f"\n{self._get_sample_rows(table)}\n" + if has_extra_info: + table_info += "*/" + tables.append(table_info) + final_str = "\n\n".join(tables) + return final_str + + def _get_sample_rows(self, table: Table) -> str: + # build the select command + command = select(table).limit(self._sample_rows_in_table_info) + + # save the columns in string format + columns_str = "\t".join([col.name for col in table.columns]) + + try: + # get the sample rows + with self._engine.connect() as connection: + sample_rows_result: CursorResult = connection.execute(command) + # shorten values in the sample rows + sample_rows = list( + map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) + ) + + # save the sample rows in string format + sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) + + # in some dialects when there are no rows in the table a + # 'ProgrammingError' is returned + except ProgrammingError: + sample_rows_str = "" + + return ( + f"{self._sample_rows_in_table_info} rows from {table.name} table:\n" + f"{columns_str}\n" + f"{sample_rows_str}" + ) + + def _get_table_indexes(self, table: Table) -> str: + indexes = self._inspector.get_indexes(table.name) + indexes_formatted = "\n".join(map(_format_index, indexes)) + return f"Table Indexes:\n{indexes_formatted}" + + def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: + """Get information about specified tables.""" + try: + return self.get_table_info(table_names) + except ValueError as e: + """Format the error message""" + return f"Error: {e}" + + def run(self, session, command: str, fetch: str = "all") -> List: + """Execute a SQL command and return a string representing the results.""" + cursor = session.execute(text(command)) + if cursor.returns_rows: + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone()[0] # type: ignore + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") + field_names = tuple(i[0:] for i in cursor.keys()) + + result = list(result) + result.insert(0, field_names) + return result + + def run_no_throw(self, session, command: str, fetch: str = "all") -> List: + """Execute a SQL command and return a string representing the results. + + If the statement returns rows, a string of the results is returned. + If the statement returns no rows, an empty string is returned. + + If the statement throws an error, the error message is returned. 
+ """ + try: + return self.run(session, command, fetch) + except SQLAlchemyError as e: + """Format the error message""" + return f"Error: {e}" + + def get_database_list(self): + session = self._db_sessions() + cursor = session.execute(text(" show databases;")) + results = cursor.fetchall() + return [ + d[0] + for d in results + if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"] + ] diff --git a/pilot/conversation.py b/pilot/conversation.py index 304469838..ba5ab2701 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -44,6 +44,7 @@ class Conversation: skip_next: bool = False conv_id: Any = None last_user_input: Any = None + def get_prompt(self): if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py index b71ad4a5f..8d60eafe7 100644 --- a/pilot/memory/chat_history/base.py +++ b/pilot/memory/chat_history/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pydantic import BaseModel, Field, root_validator, validator,Extra +from pydantic import BaseModel, Field, root_validator, validator, Extra from abc import ABC, abstractmethod from typing import ( Any, @@ -17,13 +17,9 @@ from typing import ( from pilot.scene.message import OnceConversation - - - class BaseChatHistoryMemory(ABC): - def __init__(self): - self.conversations:List[OnceConversation] = [] + self.conversations: List[OnceConversation] = [] @abstractmethod def messages(self) -> List[OnceConversation]: # type: ignore @@ -33,8 +29,6 @@ class BaseChatHistoryMemory(ABC): def append(self, message: OnceConversation) -> None: """Append the message to the record in the local file""" - @abstractmethod def clear(self) -> None: """Clear session memory from the local file""" - diff --git a/pilot/memory/chat_history/file_history.py b/pilot/memory/chat_history/file_history.py index a3d53415b..ffdd4169b 100644 --- a/pilot/memory/chat_history/file_history.py +++ b/pilot/memory/chat_history/file_history.py @@ -5,31 +5,33 @@ import datetime from pilot.memory.chat_history.base import BaseChatHistoryMemory from pathlib import Path -from pilot.configs.config import Config -from pilot.scene.message import OnceConversation, conversation_from_dict,conversations_to_dict +from pilot.configs.config import Config +from pilot.scene.message import ( + OnceConversation, + conversation_from_dict, + conversations_to_dict, +) CFG = Config() class FileHistoryMemory(BaseChatHistoryMemory): - def __init__(self, chat_session_id:str): + def __init__(self, chat_session_id: str): now = datetime.datetime.now() date_string = now.strftime("%Y%m%d") path: str = f"{CFG.message_dir}/{date_string}" os.makedirs(path, exist_ok=True) dir_path = Path(path) - self.file_path = Path(dir_path / f"{chat_session_id}.json") + self.file_path = Path(dir_path / f"{chat_session_id}.json") if not self.file_path.exists(): self.file_path.touch() self.file_path.write_text(json.dumps([])) - - def messages(self) -> List[OnceConversation]: items = json.loads(self.file_path.read_text()) - history:List[OnceConversation] = [] + history: List[OnceConversation] = [] for onece in items: messages = conversation_from_dict(onece) history.append(messages) @@ -38,8 +40,10 @@ class FileHistoryMemory(BaseChatHistoryMemory): def append(self, once_message: OnceConversation) -> None: historys = self.messages() historys.append(once_message) - self.file_path.write_text(json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4), encoding="UTF-8") + 
self.file_path.write_text(
+            json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4),
+            encoding="UTF-8",
+        )
 
     def clear(self) -> None:
         self.file_path.write_text(json.dumps([]))
-
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 881c3d034..36ca8eb9c 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -13,13 +13,15 @@ from typing import (
     TypeVar,
     Union,
 )
+from pilot.utils import build_logger
+import re
 
 from pydantic import BaseModel, Extra, Field, root_validator
-
+from pilot.configs.model_config import LOGDIR
 from pilot.prompts.base import PromptValue
 
 T = TypeVar("T")
-
+logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
 class BaseOutputParser(ABC):
     """Class to parse the output of an LLM call.
@@ -41,16 +43,16 @@ class BaseOutputParser(ABC):
         text = text.lower()
         respObj = json.loads(text)
 
-        xx = respObj['response']
-        xx = xx.strip(b'\x00'.decode())
+        xx = respObj["response"]
+        xx = xx.strip(b"\x00".decode())
         respObj_ex = json.loads(xx)
-        if respObj_ex['error_code'] == 0:
-            all_text = respObj_ex['text']
+        if respObj_ex["error_code"] == 0:
+            all_text = respObj_ex["text"]
             ### 解析返回文本,获取AI回复部分
             tmpResp = all_text.split(sep)
             last_index = -1
             for i in range(len(tmpResp)):
-                if tmpResp[i].find('assistant:') != -1:
+                if tmpResp[i].find("assistant:") != -1:
                     last_index = i
             ai_response = tmpResp[last_index]
             ai_response = ai_response.replace("assistant:", "")
@@ -60,9 +62,7 @@ class BaseOutputParser(ABC):
             print("un_stream clear response:{}", ai_response)
             return ai_response
         else:
-            raise ValueError("Model server error!code=" + respObj_ex['error_code']);
-
-
+            raise ValueError("Model server error! code=" + str(respObj_ex["error_code"]))
 
     def parse_model_server_out(self, response) -> str:
         """
@@ -87,7 +87,27 @@ class BaseOutputParser(ABC):
         Returns:
 
         """
-        pass
+        cleaned_output = model_out_text.rstrip()
+        if "```json" in cleaned_output:
+            _, cleaned_output = cleaned_output.split("```json")
+        if "```" in cleaned_output:
+            cleaned_output, _ = cleaned_output.split("```")
+        if cleaned_output.startswith("```json"):
+            cleaned_output = cleaned_output[len("```json"):]
+        if cleaned_output.startswith("```"):
+            cleaned_output = cleaned_output[len("```"):]
+        if cleaned_output.endswith("```"):
+            cleaned_output = cleaned_output[: -len("```")]
+        cleaned_output = cleaned_output.strip()
+        if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
+            logger.info("Illegal JSON in model output; extracting the outermost JSON object")
+            json_pattern = r"{(.+?)}"
+            m = re.search(json_pattern, cleaned_output)
+            if m:
+                cleaned_output = m.group(0)
+            else:
+                raise ValueError("Model server output does not follow the prompt!")
+        return cleaned_output
 
     def parse_view_response(self, ai_text) -> str:
         """
@@ -98,7 +118,7 @@ class BaseOutputParser(ABC):
         Returns:
 
         """
-        pass
+        return ai_text
 
     def get_format_instructions(self) -> str:
         """Instructions on how the LLM output should be formatted."""
diff --git a/pilot/prompts/auto_mode_prompt.py b/pilot/prompts/auto_mode_prompt.py
deleted file mode 100644
index b47d24a76..000000000
--- a/pilot/prompts/auto_mode_prompt.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import platform
-from typing import Optional
-
-import distro
-import yaml
-
-from pilot.configs.config import Config
-from pilot.prompts.generator import PromptGenerator
-from pilot.prompts.prompt import (
-    DEFAULT_PROMPT_OHTER,
-    DEFAULT_TRIGGERING_PROMPT,
-    build_default_prompt_generator,
-)
-
-
-class AutoModePrompt:
-    """ """
-
-    def __init__(
-        self,
-        ai_goals: list | None = None,
-    ) -> None:
-        """
Initialize a class instance - - Parameters: - ai_name (str): The name of the AI. - ai_role (str): The description of the AI's role. - ai_goals (list): The list of objectives the AI is supposed to complete. - api_budget (float): The maximum dollar value for API calls (0.0 means infinite) - Returns: - None - """ - if ai_goals is None: - ai_goals = [] - self.ai_goals = ai_goals - self.prompt_generator = None - self.command_registry = None - - def construct_follow_up_prompt( - self, - user_input: [str], - last_auto_return: str = None, - prompt_generator: Optional[PromptGenerator] = None, - ) -> str: - """ - Build complete prompt information based on subsequent dialogue information entered by the user - Args: - self: - prompt_generator: - - Returns: - - """ - prompt_start = DEFAULT_PROMPT_OHTER - if prompt_generator is None: - prompt_generator = build_default_prompt_generator() - prompt_generator.goals = user_input - prompt_generator.command_registry = self.command_registry - # 加载插件中可用命令 - cfg = Config() - for plugin in cfg.plugins: - if not plugin.can_handle_post_prompt(): - continue - prompt_generator = plugin.post_prompt(prompt_generator) - - full_prompt = f"{prompt_start}\n\nGOALS:\n\n" - if not self.ai_goals: - self.ai_goals = user_input - for i, goal in enumerate(self.ai_goals): - full_prompt += ( - f"{i+1}.According to the provided Schema information, {goal}\n" - ) - # if last_auto_return == None: - # full_prompt += f"{cfg.last_plugin_return}\n\n" - # else: - # full_prompt += f"{last_auto_return}\n\n" - - full_prompt += f"Constraints:\n\n{DEFAULT_TRIGGERING_PROMPT}\n" - - full_prompt += """Based on the above definition, answer the current goal and ensure that the response meets both the current constraints and the above definition and constraints""" - self.prompt_generator = prompt_generator - return full_prompt - - def construct_first_prompt( - self, - fisrt_message: [str] = [], - db_schemes: str = None, - prompt_generator: Optional[PromptGenerator] = None, - ) -> str: - """ - Build complete prompt information based on the initial dialogue information entered by the user - Args: - self: - prompt_generator: - - Returns: - - """ - prompt_start = ( - "Your decisions must always be made independently without" - " seeking user assistance. Play to your strengths as an LLM and pursue" - " simple strategies with no legal complications." - "" - ) - if prompt_generator is None: - prompt_generator = build_default_prompt_generator() - prompt_generator.goals = fisrt_message - prompt_generator.command_registry = self.command_registry - # 加载插件中可用命令 - cfg = Config() - for plugin in cfg.plugins: - if not plugin.can_handle_post_prompt(): - continue - prompt_generator = plugin.post_prompt(prompt_generator) - if cfg.execute_local_commands: - # add OS info to prompt - os_name = platform.system() - os_info = ( - platform.platform(terse=True) - if os_name != "Linux" - else distro.name(pretty=True) - ) - - prompt_start += f"\nThe OS you are running on is: {os_info}" - - # Construct full prompt - full_prompt = f"{prompt_start}\n\nGOALS:\n\n" - if not self.ai_goals: - self.ai_goals = fisrt_message - for i, goal in enumerate(self.ai_goals): - full_prompt += ( - f"{i+1}.According to the provided Schema information,{goal}\n" - ) - if db_schemes: - full_prompt += f"\nSchema:\n\n" - full_prompt += f"{db_schemes}" - - # if self.api_budget > 0.0: - # full_prompt += f"\nIt takes money to let you run. 
Your API budget is ${self.api_budget:.3f}" - self.prompt_generator = prompt_generator - full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}" - return full_prompt diff --git a/pilot/prompts/base.py b/pilot/prompts/base.py index bd082000e..12a97e94f 100644 --- a/pilot/prompts/base.py +++ b/pilot/prompts/base.py @@ -1,5 +1,3 @@ - - import json from abc import ABC, abstractmethod from pathlib import Path @@ -8,7 +6,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union import yaml from pydantic import BaseModel, Extra, Field, root_validator -from pilot.scene.base_message import BaseMessage,HumanMessage,AIMessage, SystemMessage +from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage def get_buffer_string( @@ -29,7 +27,6 @@ def get_buffer_string( return "\n".join(string_messages) - class PromptValue(BaseModel, ABC): @abstractmethod def to_string(self) -> str: @@ -39,6 +36,7 @@ class PromptValue(BaseModel, ABC): def to_messages(self) -> List[BaseMessage]: """Return prompt as messages.""" + class ChatPromptValue(PromptValue): messages: List[BaseMessage] @@ -48,4 +46,4 @@ class ChatPromptValue(PromptValue): def to_messages(self) -> List[BaseMessage]: """Return prompt as messages.""" - return self.messages \ No newline at end of file + return self.messages diff --git a/pilot/prompts/generator.py b/pilot/prompts/generator.py index c470ff5a5..22f998a67 100644 --- a/pilot/prompts/generator.py +++ b/pilot/prompts/generator.py @@ -3,7 +3,7 @@ import json from typing import Any, Callable, Dict, List, Optional -class PromptGenerator: +class PluginPromptGenerator: """ A class for generating custom prompt strings based on constraints, commands, resources, and performance evaluations. @@ -133,6 +133,11 @@ class PromptGenerator: else: return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) + + def generate_commands_string(self)->str: + return f"{self._generate_numbered_list(self.commands, item_type='command')}" + + def generate_prompt_string(self) -> str: """ Generate a prompt string based on the constraints, commands, resources, diff --git a/pilot/prompts/prompt.py b/pilot/prompts/prompt.py deleted file mode 100644 index d46b69ad5..000000000 --- a/pilot/prompts/prompt.py +++ /dev/null @@ -1,73 +0,0 @@ -from pilot.configs.config import Config -from pilot.prompts.generator import PromptGenerator - -CFG = Config() - -DEFAULT_TRIGGERING_PROMPT = ( - "Determine which next command to use, and respond using the format specified above" -) - -DEFAULT_PROMPT_OHTER = "Previous response was excellent. Please response according to the requirements based on the new goal" - - -def build_default_prompt_generator() -> PromptGenerator: - """ - This function generates a prompt string that includes various constraints, - commands, resources, and performance evaluations. - - Returns: - str: The generated prompt string. - """ - - # Initialize the PromptGenerator object - prompt_generator = PromptGenerator() - - # Add constraints to the PromptGenerator object - # prompt_generator.add_constraint( - # "~4000 word limit for short term memory. Your short term memory is short, so" - # " immediately save important information to files." - # ) - prompt_generator.add_constraint( - "If you are unsure how you previously did something or want to recall past" - " events, thinking about similar events will help you remember." 
- ) - # prompt_generator.add_constraint("No user assistance") - - prompt_generator.add_constraint("Only output one correct JSON response at a time") - prompt_generator.add_constraint( - 'Exclusively use the commands listed in double quotes e.g. "command name"' - ) - prompt_generator.add_constraint( - "If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition" - ) - prompt_generator.add_constraint( - "The generated command args need to comply with the definition of the command" - ) - - # Add resources to the PromptGenerator object - # prompt_generator.add_resource( - # "Internet access for searches and information gathering." - # ) - # prompt_generator.add_resource("Long Term memory management.") - # prompt_generator.add_resource( - # "DB-GPT powered Agents for delegation of simple tasks." - # ) - # prompt_generator.add_resource("File output.") - - # Add performance evaluations to the PromptGenerator object - prompt_generator.add_performance_evaluation( - "Continuously review and analyze your actions to ensure you are performing to" - " the best of your abilities." - ) - prompt_generator.add_performance_evaluation( - "Constructively self-criticize your big-picture behavior constantly." - ) - prompt_generator.add_performance_evaluation( - "Reflect on past decisions and strategies to refine your approach." - ) - # prompt_generator.add_performance_evaluation( - # "Every command has a cost, so be smart and efficient. Aim to complete tasks in" - # " the least number of steps." - # ) - # prompt_generator.add_performance_evaluation("Write all code to a file.") - return prompt_generator diff --git a/pilot/prompts/prompt_generator.py b/pilot/prompts/prompt_generator.py index e0ffed4a6..1ec62d5c9 100644 --- a/pilot/prompts/prompt_generator.py +++ b/pilot/prompts/prompt_generator.py @@ -3,8 +3,8 @@ from typing import Any, Callable, Dict, List, Optional class PromptGenerator: """ - generating custom prompt strings based on constraints; - Compatible with AutoGpt Plugin; + generating custom prompt strings based on constraints; + Compatible with AutoGpt Plugin; """ def __init__(self) -> None: @@ -22,8 +22,6 @@ class PromptGenerator: self.role = "AI" self.response_format = None - - def add_command( self, command_label: str, @@ -51,4 +49,4 @@ class PromptGenerator: "args": command_args, "function": function, } - self.commands.append(command) \ No newline at end of file + self.commands.append(command) diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index a79be5171..389b1a33e 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -8,6 +8,7 @@ from pilot.common.formatting import formatter from pilot.out_parser.base import BaseOutputParser from pilot.common.schema import SeparatorStyle + def jinja2_formatter(template: str, **kwargs: Any) -> str: """Format a template using jinja2.""" try: @@ -32,22 +33,23 @@ class PromptTemplate(BaseModel, ABC): """A list of the names of the variables the prompt template expects.""" template_scene: str - template_define:str + template_define: str """this template define""" template: str """The prompt template.""" template_format: str = "f-string" """The format of the prompt template. 
Options are: 'f-string', 'jinja2'.""" - response_format:str + response_format: str """default use stream out""" stream_out: bool = True """""" output_parser: BaseOutputParser = None """""" - sep:str = SeparatorStyle.SINGLE.value + sep: str = SeparatorStyle.SINGLE.value class Config: """Configuration for this pydantic object.""" + arbitrary_types_allowed = True @property @@ -96,10 +98,8 @@ class PromptTemplate(BaseModel, ABC): else: return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) - - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - """ + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs.""" kwargs["response"] = json.dumps(self.response_format, indent=4) return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) diff --git a/pilot/prompts/prompt_template.py b/pilot/prompts/prompt_template.py index ad597c33d..0a014b06e 100644 --- a/pilot/prompts/prompt_template.py +++ b/pilot/prompts/prompt_template.py @@ -207,6 +207,7 @@ class BasePromptTemplate(BaseModel, ABC): else: raise ValueError(f"{save_path} must be json or yaml") + class StringPromptValue(PromptValue): text: str @@ -219,7 +220,6 @@ class StringPromptValue(PromptValue): return [HumanMessage(content=self.text)] - class StringPromptTemplate(BasePromptTemplate, ABC): """String prompt should expose the format method, returning a prompt.""" @@ -360,4 +360,4 @@ class PromptTemplate(StringPromptTemplate): # For backwards compatibility. -Prompt = PromptTemplate \ No newline at end of file +Prompt = PromptTemplate diff --git a/pilot/scene/base.py b/pilot/scene/base.py index 302510f2b..9fcc6fb31 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -1,8 +1,9 @@ from enum import Enum + class ChatScene(Enum): ChatWithDb = "chat_with_db" ChatExecution = "chat_execution" ChatKnowledge = "chat_default_knowledge" ChatNewKnowledge = "chat_new_knowledge" - ChatNormal = "chat_normal" \ No newline at end of file + ChatNormal = "chat_normal" diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index bc36287a6..7a1c77781 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -1,4 +1,6 @@ from abc import ABC, abstractmethod +import datetime +import traceback from pydantic import BaseModel, Field, root_validator, validator, Extra from typing import ( Any, @@ -20,20 +22,27 @@ from pilot.prompts.prompt_new import PromptTemplate from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory -from pilot.configs.model_config import LOGDIR, DATASETS_DIR +from pilot.configs.model_config import LOGDIR, DATASETS_DIR from pilot.utils import ( build_logger, server_error_msg, ) -from pilot.common.schema import SeparatorStyle -from pilot.scene.base import ChatScene +from pilot.scene.base_message import ( + BaseMessage, + SystemMessage, + HumanMessage, + AIMessage, + ViewMessage, +) from pilot.configs.config import Config logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") headers = {"User-Agent": "dbgpt Client"} CFG = Config() -class BaseChat( ABC): - chat_scene:str = None + + +class BaseChat(ABC): + chat_scene: str = None llm_model: Any = None temperature: float = 0.6 max_new_tokens: int = 1024 @@ -42,17 +51,20 @@ class BaseChat( ABC): class Config: """Configuration for this pydantic object.""" + arbitrary_types_allowed = True def __init__(self, chat_mode, chat_session_id, current_user_input): self.chat_session_id = chat_session_id self.chat_mode = 
chat_mode
-        self.current_user_input:str = current_user_input
+        self.current_user_input: str = current_user_input
         self.llm_model = CFG.LLM_MODEL
         ### TODO
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
-        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[
+            self.chat_mode.value
+        ]
         self.history_message: List[OnceConversation] = []
         self.current_message: OnceConversation = OnceConversation()
         self.current_tokens_used: int = 0
@@ -69,15 +81,164 @@ class BaseChat( ABC):
     def chat_type(self) -> str:
         raise NotImplementedError("Not supported for this chat type.")
 
+    @abstractmethod
+    def generate_input_values(self):
+        pass
+
+    @abstractmethod
+    def do_with_prompt_response(self, prompt_response):
+        pass
+
     def call(self):
-        pass
+        input_values = self.generate_input_values()
+
+        ### Chat sequence advance
+        self.current_message.chat_order = len(self.history_message) + 1
+        self.current_message.add_user_message(self.current_user_input)
+        self.current_message.start_date = datetime.datetime.now()
+        # TODO
+        self.current_message.tokens = 0
+
+        current_prompt = self.prompt_template.format(**input_values)
+
+        ### Build the current conversation. Should the prompt be built as in the first round? Should a database switch be considered?
+        if self.history_message:
+            ## TODO For conversations that carry history, decide how to handle a database switch
+            logger.info(
+                f"There are already {len(self.history_message)} rounds of conversations!"
+            )
+
+        self.current_message.add_system_message(current_prompt)
+
+        payload = {
+            "model": self.llm_model,
+            "prompt": self.generate_llm_text(),
+            "temperature": float(self.temperature),
+            "max_new_tokens": int(self.max_new_tokens),
+            "stop": self.prompt_template.sep,
+        }
+        logger.info(f"Request: \n{payload}")
+        ai_response_text = ""
+        try:
+            if not self.prompt_template.stream_out:
+                ### Call the non-streaming model service endpoint
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+
+                ### output parse
+                ai_response_text = (
+                    self.prompt_template.output_parser.parse_model_server_out(response)
+                )
+                self.current_message.add_ai_message(ai_response_text)
+                prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+
+                result = self.do_with_prompt_response(prompt_define_response)
+
+                if hasattr(prompt_define_response, "thoughts"):
+                    if prompt_define_response.thoughts.get("speak"):
+                        self.current_message.add_view_message(
+                            self.prompt_template.output_parser.parse_view_response(
+                                prompt_define_response.thoughts.get("speak"), result
+                            )
+                        )
+                    elif prompt_define_response.thoughts.get("reasoning"):
+                        self.current_message.add_view_message(
+                            self.prompt_template.output_parser.parse_view_response(
+                                prompt_define_response.thoughts.get("reasoning"), result
+                            )
+                        )
+                    else:
+                        self.current_message.add_view_message(
+                            self.prompt_template.output_parser.parse_view_response(
+                                prompt_define_response.thoughts, result
+                            )
+                        )
+                else:
+                    self.current_message.add_view_message(
+                        self.prompt_template.output_parser.parse_view_response(
+                            prompt_define_response, result
+                        )
+                    )
+            else:
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+                # TODO
+
+
+
+        except Exception as e:
+            print(traceback.format_exc())
+            logger.error("model response parse failed! " + str(e))
+            self.current_message.add_view_message(
+                f"""ERROR!{str(e)}\n {ai_response_text} """
+            )
+        ### Persist the conversation record
+        self.memory.append(self.current_message)
+
+    def generate_llm_text(self) -> str:
+        text = self.prompt_template.template_define + self.prompt_template.sep
+        ### Process the history messages first
+        if len(self.history_message) > self.chat_retention_rounds:
+            ### Merge the first and the most recent rounds of history into the context; view messages shown to the user are filtered out
+            for first_message in self.history_message[0].messages:
+                if not isinstance(first_message, ViewMessage):
+                    text += (
+                        first_message.type
+                        + ":"
+                        + first_message.content
+                        + self.prompt_template.sep
+                    )
+
+            index = self.chat_retention_rounds - 1
+            for conversation in self.history_message[-index:]:
+                for last_message in conversation.messages:
+                    if not isinstance(last_message, ViewMessage):
+                        text += (
+                            last_message.type
+                            + ":"
+                            + last_message.content
+                            + self.prompt_template.sep
+                        )
+
+        else:
+            ### Concatenate the history directly
+            for conversation in self.history_message:
+                for message in conversation.messages:
+                    if not isinstance(message, ViewMessage):
+                        text += (
+                            message.type
+                            + ":"
+                            + message.content
+                            + self.prompt_template.sep
+                        )
+
+        ### current conversation
+        for now_message in self.current_message.messages:
+            text += (
+                now_message.type + ":" + now_message.content + self.prompt_template.sep
+            )
+
+        return text
+
     def chat_show(self):
         pass
 
+    # Temporary: kept for front-end compatibility
     def current_ai_response(self) -> str:
-        pass
+        for message in self.current_message.messages:
+            if message.type == "view":
+                return message.content
+        return None
 
     def _load_history(self, session_id: str) -> List[OnceConversation]:
         """
@@ -88,7 +249,7 @@ class BaseChat( ABC):
         """
         return self.memory.messages()
 
-    def generate(self, p)->str:
+    def generate(self, p) -> str:
         """
         generate context for LLM input
         Args:
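With this hunk, BaseChat.call() becomes a template method: it builds the prompt, calls the model service, parses the reply, and persists the conversation, while each scene subclass only supplies the two new hooks. A minimal sketch of a conforming subclass (illustrative only; the class name and scene binding are assumptions, and ChatWithDb later in this patch is the real in-tree example):

    # Hypothetical scene, shown only to illustrate the two required hooks.
    class EchoChat(BaseChat):
        chat_scene: str = ChatScene.ChatNormal.value  # assumed scene binding

        def generate_input_values(self):
            # Keys must match the placeholders of this scene's PromptTemplate.
            return {"input": self.current_user_input}

        def do_with_prompt_response(self, prompt_response):
            # Receives the object produced by output_parser.parse_prompt_response();
            # the return value is handed on to parse_view_response for display.
            return prompt_response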
diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py
index 5cd8c4426..56fbb3b20 100644
--- a/pilot/scene/base_message.py
+++ b/pilot/scene/base_message.py
@@ -15,6 +15,7 @@ from typing import (
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
+
 class PromptValue(BaseModel, ABC):
     @abstractmethod
     def to_string(self) -> str:
@@ -37,7 +38,6 @@ class BaseMessage(BaseModel):
     """Type of the message, used for serialization."""
 
 
-
 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
 
@@ -49,7 +49,6 @@ class HumanMessage(BaseMessage):
         return "human"
 
 
-
 class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
 
@@ -81,8 +80,6 @@ class SystemMessage(BaseMessage):
         return "system"
 
 
-
-
 class Generation(BaseModel):
     """Output of a single generation."""
 
@@ -94,7 +91,6 @@ class Generation(BaseModel):
     """May include things like reason for finishing (e.g. 
in OpenAI)""" - class ChatGeneration(Generation): """Output of a single generation.""" @@ -126,7 +122,6 @@ class LLMResult(BaseModel): """For arbitrary LLM provider specific output.""" - def _message_to_dict(message: BaseMessage) -> dict: return {"type": message.type, "data": message.dict()} @@ -149,6 +144,5 @@ def _message_from_dict(message: dict) -> BaseMessage: raise ValueError(f"Got unexpected type: {_type}") - def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: return [_message_from_dict(m) for m in messages] diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py index 72ad64508..37382b7d2 100644 --- a/pilot/scene/chat_db/chat.py +++ b/pilot/scene/chat_db/chat.py @@ -14,7 +14,13 @@ from sqlalchemy import ( ) from typing import Any, Iterable, List, Optional -from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage, ViewMessage +from pilot.scene.base_message import ( + BaseMessage, + SystemMessage, + HumanMessage, + AIMessage, + ViewMessage, +) from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base import ChatScene from pilot.common.sql_database import Database @@ -25,22 +31,24 @@ from pilot.utils import ( build_logger, server_error_msg, ) -from pilot.common.markdown_text import generate_markdown_table, generate_htm_table, datas_to_table_html +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) from pilot.scene.chat_db.prompt import chat_db_prompt from pilot.out_parser.base import BaseOutputParser from pilot.scene.chat_db.out_parser import DbChatOutputParser CFG = Config() - class ChatWithDb(BaseChat): chat_scene: str = ChatScene.ChatWithDb.value """Number of results to return from the query""" def __init__(self, chat_session_id, db_name, user_input): - """ - """ + """ """ super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input) if not db_name: raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!") @@ -50,118 +58,126 @@ class ChatWithDb(BaseChat): self.db_connect = self.database.get_session(self.db_name) self.top_k: int = 5 - def call(self) -> str: + def generate_input_values(self): input_values = { - "input": self.current_user_input, - "top_k": str(self.top_k), - "dialect": self.database.dialect, - "table_info": self.database.table_simple_info(self.db_connect), - # "stop": self.sep_style, - } + "input": self.current_user_input, + "top_k": str(self.top_k), + "dialect": self.database.dialect, + "table_info": self.database.table_simple_info(self.db_connect) + } + return input_values - ### Chat sequence advance - self.current_message.chat_order = len(self.history_message) + 1 - self.current_message.add_user_message(self.current_user_input) - self.current_message.start_date = datetime.datetime.now() - # TODO - self.current_message.tokens = 0 - - current_prompt = self.prompt_template.format(**input_values) - - ### 构建当前对话, 是否安第一次对话prompt构造? 
是否考虑切换库 - if self.history_message: - ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 - logger.info(f"There are already {len(self.history_message)} rounds of conversations!") - - self.current_message.add_system_message(current_prompt) - - payload = { - "model": self.llm_model, - "prompt": self.generate_llm_text(), - "temperature": float(self.temperature), - "max_new_tokens": int(self.max_new_tokens), - "stop": self.prompt_template.sep, - } - logger.info(f"Requert: \n{payload}") - ai_response_text = "" - try: - ### 走非流式的模型服务接口 - - response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120) - ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) - self.current_message.add_ai_message(ai_response_text) - prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) - - result = self.database.run(self.db_connect, prompt_define_response.sql) + def do_with_prompt_response(self, prompt_response): + return self.database.run(self.db_connect, prompt_response.sql) - if hasattr(prompt_define_response, 'thoughts'): - if prompt_define_response.thoughts.get("speak"): - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("speak"),result)) - elif prompt_define_response.thoughts.get("reasoning"): - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts.get("reasoning"), result)) - else: - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response(prompt_define_response.thoughts, result)) - else: - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response(prompt_define_response, result)) - - except Exception as e: - print(traceback.format_exc()) - logger.error("model response parase faild!" + str(e)) - self.current_message.add_view_message(f"""ERROR!{str(e)}\n {ai_response_text} """) - ### 对话记录存储 - self.memory.append(self.current_message) + # def call(self) -> str: + # input_values = { + # "input": self.current_user_input, + # "top_k": str(self.top_k), + # "dialect": self.database.dialect, + # "table_info": self.database.table_simple_info(self.db_connect), + # # "stop": self.sep_style, + # } + # + # ### Chat sequence advance + # self.current_message.chat_order = len(self.history_message) + 1 + # self.current_message.add_user_message(self.current_user_input) + # self.current_message.start_date = datetime.datetime.now() + # # TODO + # self.current_message.tokens = 0 + # + # current_prompt = self.prompt_template.format(**input_values) + # + # ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 + # if self.history_message: + # ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 + # logger.info( + # f"There are already {len(self.history_message)} rounds of conversations!" 
+ # ) + # + # self.current_message.add_system_message(current_prompt) + # + # payload = { + # "model": self.llm_model, + # "prompt": self.generate_llm_text(), + # "temperature": float(self.temperature), + # "max_new_tokens": int(self.max_new_tokens), + # "stop": self.prompt_template.sep, + # } + # logger.info(f"Requert: \n{payload}") + # ai_response_text = "" + # try: + # ### 走非流式的模型服务接口 + # + # response = requests.post( + # urljoin(CFG.MODEL_SERVER, "generate"), + # headers=headers, + # json=payload, + # timeout=120, + # ) + # ai_response_text = ( + # self.prompt_template.output_parser.parse_model_server_out(response) + # ) + # self.current_message.add_ai_message(ai_response_text) + # prompt_define_response = ( + # self.prompt_template.output_parser.parse_prompt_response( + # ai_response_text + # ) + # ) + # + # result = self.database.run(self.db_connect, prompt_define_response.sql) + # + # if hasattr(prompt_define_response, "thoughts"): + # if prompt_define_response.thoughts.get("speak"): + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response.thoughts.get("speak"), result + # ) + # ) + # elif prompt_define_response.thoughts.get("reasoning"): + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response.thoughts.get("reasoning"), result + # ) + # ) + # else: + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response.thoughts, result + # ) + # ) + # else: + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response, result + # ) + # ) + # + # except Exception as e: + # print(traceback.format_exc()) + # logger.error("model response parase faild!" 
+ str(e)) + # self.current_message.add_view_message( + # f"""ERROR!{str(e)}\n {ai_response_text} """ + # ) + # ### 对话记录存储 + # self.memory.append(self.current_message) def chat_show(self): ret = [] # 单论对话只能有一次User 记录 和一次 AI 记录 # TODO 推理过程前端展示。。。 for message in self.current_message.messages: - if (isinstance(message, HumanMessage)): + if isinstance(message, HumanMessage): ret[-1][-2] = message.content # 是否展示推理过程 - if (isinstance(message, ViewMessage)): + if isinstance(message, ViewMessage): ret[-1][-1] = message.content return ret - # 暂时为了兼容前端 - def current_ai_response(self) -> str: - for message in self.current_message.messages: - if message.type == 'view': - return message.content - return None - def generate_llm_text(self) -> str: - text = self.prompt_template.template_define + self.prompt_template.sep - ### 线处理历史信息 - if (len(self.history_message) > self.chat_retention_rounds): - ### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 - for first_message in self.history_message[0].messages: - if not isinstance(first_message, ViewMessage): - text += first_message.type + ":" + first_message.content + self.prompt_template.sep - index = self.chat_retention_rounds - 1 - for last_message in self.history_message[-index:].messages: - if not isinstance(last_message, ViewMessage): - text += last_message.type + ":" + last_message.content + self.prompt_template.sep - - else: - ### 直接历史记录拼接 - for conversation in self.history_message: - for message in conversation.messages: - if not isinstance(message, ViewMessage): - text += message.type + ":" + message.content + self.prompt_template.sep - - ### current conversation - for now_message in self.current_message.messages: - text += now_message.type + ":" + now_message.content + self.prompt_template.sep - - return text @property def chat_type(self) -> str: diff --git a/pilot/scene/chat_db/out_parser.py b/pilot/scene/chat_db/out_parser.py index 1d2597f57..307aff680 100644 --- a/pilot/scene/chat_db/out_parser.py +++ b/pilot/scene/chat_db/out_parser.py @@ -1,52 +1,28 @@ import json import re from abc import ABC, abstractmethod -from typing import ( - Dict, - NamedTuple -) +from typing import Dict, NamedTuple import pandas as pd from pilot.utils import build_logger from pilot.out_parser.base import BaseOutputParser, T from pilot.configs.model_config import LOGDIR + class SqlAction(NamedTuple): sql: str thoughts: Dict + logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + + class DbChatOutputParser(BaseOutputParser): + def __init__(self, sep: str, is_stream_out: bool): + super().__init__(sep=sep, is_stream_out=is_stream_out) - def __init__(self, sep:str, is_stream_out: bool): - super().__init__(sep=sep, is_stream_out=is_stream_out ) - - - def parse_model_server_out(self, response) -> str: - return super().parse_model_server_out(response) def parse_prompt_response(self, model_out_text): - cleaned_output = model_out_text.rstrip() - if "```json" in cleaned_output: - _, cleaned_output = cleaned_output.split("```json") - if "```" in cleaned_output: - cleaned_output, _ = cleaned_output.split("```") - if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json"):] - if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```"):] - if cleaned_output.endswith("```"): - cleaned_output = cleaned_output[: -len("```")] - cleaned_output = cleaned_output.strip() - if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"): - logger.info("illegal json processing") - json_pattern = r'{(.+?)}' - m = 
re.search(json_pattern, cleaned_output) - if m: - cleaned_output = m.group(0) - else: - raise ValueError("model server out not fllow the prompt!") - - response = json.loads(cleaned_output) + response = json.loads(super().parse_prompt_response(model_out_text)) sql, thoughts = response["sql"], response["thoughts"] return SqlAction(sql, thoughts) diff --git a/pilot/scene/chat_db/prompt.py b/pilot/scene/chat_db/prompt.py index 8ff1a2b1b..aeaf994c0 100644 --- a/pilot/scene/chat_db/prompt.py +++ b/pilot/scene/chat_db/prompt.py @@ -45,9 +45,15 @@ RESPONSE_FORMAT = { "reasoning": "reasoning", "speak": "thoughts summary to say to user", }, - "sql": "SQL Query to run" + "sql": "SQL Query to run", } +RESPONSE_FORMAT_SIMPLE = { + "thoughts": "thoughts summary to say to user", + "sql": "SQL Query to run", +} + + PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = False @@ -55,11 +61,13 @@ PROMPT_NEED_NEED_STREAM_OUT = False chat_db_prompt = PromptTemplate( template_scene=ChatScene.ChatWithDb.value, input_variables=["input", "table_info", "dialect", "top_k", "response"], - response_format=json.dumps(RESPONSE_FORMAT, indent=4), + response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4), template_define=PROMPT_SCENE_DEFINE, template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE, stream_out=PROMPT_NEED_NEED_STREAM_OUT, - output_parser=DbChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT), + output_parser=DbChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), ) CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt}) diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index 5e85c4981..a5abadad0 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -1,26 +1,156 @@ +import requests +import datetime +from urllib.parse import urljoin from typing import List +import traceback from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.message import OnceConversation from pilot.scene.base import ChatScene +from pilot.configs.config import Config +from pilot.commands.command import execute_command +from pilot.prompts.generator import PluginPromptGenerator + +CFG = Config() class ChatWithPlugin(BaseChat): - chat_scene: str= ChatScene.ChatExecution.value - def __init__(self, chat_mode, chat_session_id, current_user_input): - super().__init__(chat_mode, chat_session_id, current_user_input) + chat_scene: str = ChatScene.ChatExecution.value + plugins_prompt_generator:PluginPromptGenerator + select_plugin: str = None - def call(self): - super().call() + def __init__(self, chat_mode, chat_session_id, current_user_input, select_plugin:str=None): + super().__init__(chat_mode, chat_session_id, current_user_input) + self.plugins_prompt_generator = PluginPromptGenerator() + self.plugins_prompt_generator.command_registry = self.command_registry + # 加载插件中可用命令 + self.select_plugin = select_plugin + if self.select_plugin: + for plugin in CFG.plugins: + if plugin. 
+ else: + for plugin in CFG.plugins: + if not plugin.can_handle_post_prompt(): + continue + self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator) + + + + + def generate_input_values(self): + input_values = { + "input": self.current_user_input, + "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints), + "commands_infos": self.plugins_prompt_generator.generate_commands_string() + } + return input_values + + def do_with_prompt_response(self, prompt_response): + ## plugin command run + return execute_command(str(prompt_response), self.plugins_prompt_generator) + + + # def call(self): + # input_values = { + # "input": self.current_user_input, + # "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints), + # "commands_infos": self.__get_comnands_promp_info() + # } + # + # ### Chat sequence advance + # self.current_message.chat_order = len(self.history_message) + 1 + # self.current_message.add_user_message(self.current_user_input) + # self.current_message.start_date = datetime.datetime.now() + # # TODO + # self.current_message.tokens = 0 + # + # current_prompt = self.prompt_template.format(**input_values) + # + # ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 + # if self.history_message: + # ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 + # logger.info( + # f"There are already {len(self.history_message)} rounds of conversations!" + # ) + # + # self.current_message.add_system_message(current_prompt) + # + # payload = { + # "model": self.llm_model, + # "prompt": self.generate_llm_text(), + # "temperature": float(self.temperature), + # "max_new_tokens": int(self.max_new_tokens), + # "stop": self.prompt_template.sep, + # } + # logger.info(f"Requert: \n{payload}") + # ai_response_text = "" + # try: + # ### 走非流式的模型服务接口 + # + # response = requests.post( + # urljoin(CFG.MODEL_SERVER, "generate"), + # headers=headers, + # json=payload, + # timeout=120, + # ) + # ai_response_text = ( + # self.prompt_template.output_parser.parse_model_server_out(response) + # ) + # self.current_message.add_ai_message(ai_response_text) + # prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + # + # + # ## plugin command run + # result = execute_command(prompt_define_response, self.plugins_prompt_generator) + # + # if hasattr(prompt_define_response, "thoughts"): + # if prompt_define_response.thoughts.get("speak"): + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response.thoughts.get("speak"), result + # ) + # ) + # elif prompt_define_response.thoughts.get("reasoning"): + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response.thoughts.get("reasoning"), result + # ) + # ) + # else: + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response.thoughts, result + # ) + # ) + # else: + # self.current_message.add_view_message( + # self.prompt_template.output_parser.parse_view_response( + # prompt_define_response, result + # ) + # ) + # + # except Exception as e: + # print(traceback.format_exc()) + # logger.error("model response parase faild!" 
+ str(e)) + # self.current_message.add_view_message( + # f"""ERROR!{str(e)}\n {ai_response_text} """ + # ) + # ### 对话记录存储 + # self.memory.append(self.current_message) def chat_show(self): super().chat_show() - def _load_history(self, session_id: str) -> List[OnceConversation]: - return super()._load_history(session_id) + + def __list_to_prompt_str(list: List) -> str: + if not list: + separator = '\n' + return separator.join(list) + else: + return "" def generate(self, p) -> str: return super().generate(p) @property def chat_type(self) -> str: - return ChatScene.ChatExecution.value \ No newline at end of file + return ChatScene.ChatExecution.value diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py new file mode 100644 index 000000000..f3f9e683e --- /dev/null +++ b/pilot/scene/chat_execution/out_parser.py @@ -0,0 +1,30 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class PluginAction(NamedTuple): + command: Dict + thoughts: Dict + + + +class PluginChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + response = json.loads(super().parse_prompt_response(model_out_text)) + sql, thoughts = response["command"], response["thoughts"] + return PluginAction(sql, thoughts) + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py new file mode 100644 index 000000000..e3469d7c2 --- /dev/null +++ b/pilot/scene/chat_execution/prompt.py @@ -0,0 +1,65 @@ +import json +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_execution.out_parser import PluginChatOutputParser + + +CFG = Config() + +PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications.""" + +PROMPT_SUFFIX = """ +Goals: + {input} + +""" + +_DEFAULT_TEMPLATE = """ +Constraints: + Exclusively use the commands listed in double quotes e.g. "command name" + Reflect on past decisions and strategies to refine your approach. + Constructively self-criticize your big-picture behavior constantly. 
+    {constraints}
+
+Commands:
+    {commands_infos}
+"""
+
+
+PROMPT_RESPONSE = """You must respond in JSON format, using the following format:
+{response}
+
+Ensure the response is valid JSON and can be parsed by Python json.loads.
+"""
+
+RESPONSE_FORMAT = {
+    "thoughts": {
+        "text": "thought",
+        "reasoning": "reasoning",
+        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
+        "criticism": "constructive self-criticism",
+        "speak": "thoughts summary to say to user",
+    },
+    "command": {"name": "command name", "args": {"arg name": "value"}},
+}
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+### Whether the model service is streaming output
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+chat_plugin_prompt = PromptTemplate(
+    template_scene=ChatScene.ChatExecution.value,
+    input_variables=["input", "table_info", "dialect", "top_k", "response"],
+    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
+    template_define=PROMPT_SCENE_DEFINE,
+    template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=PluginChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+)
+
+CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})
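
Note: for reference, a model reply that satisfies this template would look like the
sketch below. PluginChatOutputParser.parse_prompt_response() recovers the command and
thoughts from it via json.loads (the command name and arguments shown are hypothetical):

    import json

    # a hypothetical, well-formed model reply for the plugin scene
    model_reply = """{
        "thoughts": {
            "text": "the user wants an image",
            "reasoning": "the image generation command matches the goal",
            "plan": "- run the command\\n- return the file path",
            "criticism": "none",
            "speak": "I will generate the image now."
        },
        "command": {"name": "generate_image", "args": {"prompt": "a red fox"}}
    }"""

    response = json.loads(model_reply)
    command, thoughts = response["command"], response["thoughts"]
    assert command["name"] == "generate_image"
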
diff --git a/pilot/scene/chat_execution/prompt_with_command.py b/pilot/scene/chat_execution/prompt_with_command.py
new file mode 100644
index 000000000..e3469d7c2
--- /dev/null
+++ b/pilot/scene/chat_execution/prompt_with_command.py
@@ -0,0 +1,65 @@
+import json
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
+
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications."""
+
+PROMPT_SUFFIX = """
+Goals:
+    {input}
+
+"""
+
+_DEFAULT_TEMPLATE = """
+Constraints:
+    Exclusively use the commands listed in double quotes e.g. "command name"
+    Reflect on past decisions and strategies to refine your approach.
+    Constructively self-criticize your big-picture behavior constantly.
+    {constraints}
+
+Commands:
+    {commands_infos}
+"""
+
+
+PROMPT_RESPONSE = """You must respond in JSON format, using the following format:
+{response}
+
+Ensure the response is valid JSON and can be parsed by Python json.loads.
+"""
+
+RESPONSE_FORMAT = {
+    "thoughts": {
+        "text": "thought",
+        "reasoning": "reasoning",
+        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
+        "criticism": "constructive self-criticism",
+        "speak": "thoughts summary to say to user",
+    },
+    "command": {"name": "command name", "args": {"arg name": "value"}},
+}
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+### Whether the model service is streaming output
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+chat_plugin_prompt = PromptTemplate(
+    template_scene=ChatScene.ChatExecution.value,
+    input_variables=["input", "constraints", "commands_infos", "response"],
+    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
+    template_define=PROMPT_SCENE_DEFINE,
+    template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=PluginChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+)
+
+CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
index 2d72fc7fe..97c547390 100644
--- a/pilot/scene/chat_factory.py
+++ b/pilot/scene/chat_factory.py
@@ -1,19 +1,17 @@
-
 from pilot.scene.base_chat import BaseChat
 from pilot.singleton import Singleton
 from pilot.scene.chat_db.chat import ChatWithDb
 from pilot.scene.chat_execution.chat import ChatWithPlugin
 
-class ChatFactory(metaclass=Singleton):
 
+class ChatFactory(metaclass=Singleton):
     @staticmethod
     def get_implementation(chat_mode, **kwargs):
-
         chat_classes = BaseChat.__subclasses__()
         implementation = None
         for cls in chat_classes:
-            if(cls.chat_scene == chat_mode):
+            if cls.chat_scene == chat_mode:
                 implementation = cls(**kwargs)
-        if(implementation == None):
-            raise Exception('Invalid implementation name:' + chat_mode)
-        return implementation
\ No newline at end of file
+        if implementation is None:
+            raise Exception("Invalid implementation name:" + chat_mode)
+        return implementation
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
index 8dc3eaa3e..0203ec68c 100644
--- a/pilot/scene/message.py
+++ b/pilot/scene/message.py
@@ -9,12 +9,20 @@ from typing import (
     List,
 )
 
-from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, ViewMessage, messages_to_dict, messages_from_dict
+from pilot.scene.base_message import (
+    BaseMessage,
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    ViewMessage,
+    messages_to_dict,
+    messages_from_dict,
+)
 
 
 class OnceConversation:
     """
-    All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
+    All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
     """
 
     def __init__(self):
@@ -26,7 +34,9 @@
 
     def add_user_message(self, message: str) -> None:
         """Add a user message to the store"""
-        has_message = any(isinstance(instance, HumanMessage) for instance in self.messages)
+        has_message = any(
+            isinstance(instance, HumanMessage) for instance in self.messages
+        )
         if has_message:
             raise ValueError("Already Have Human message")
         self.messages.append(HumanMessage(content=message))
@@ -38,6 +48,7 @@
             raise ValueError("Already Have Ai
message") self.messages.append(AIMessage(content=message)) """ """ + def add_view_message(self, message: str) -> None: """Add an AI message to the store""" @@ -50,7 +61,7 @@ class OnceConversation: def set_start_time(self, datatime: datetime): dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S") - self.start_date = dt_str; + self.start_date = dt_str def clear(self) -> None: """Remove all messages from the store""" @@ -71,7 +82,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict: "start_date": start_str, "cost": once.cost if once.cost else 0, "tokens": once.tokens if once.tokens else 0, - "messages": messages_to_dict(once.messages) + "messages": messages_to_dict(once.messages), } @@ -81,10 +92,10 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: def conversation_from_dict(once: dict) -> OnceConversation: conversation = OnceConversation() - conversation.cost = once.get('cost', 0) - conversation.tokens = once.get('tokens', 0) - conversation.start_date = once.get('start_date', '') - conversation.chat_order = int(once.get('chat_order')) - print(once.get('messages')) - conversation.messages = messages_from_dict(once.get('messages', [])) + conversation.cost = once.get("cost", 0) + conversation.tokens = once.get("tokens", 0) + conversation.start_date = once.get("start_date", "") + conversation.chat_order = int(once.get("chat_order")) + print(once.get("messages")) + conversation.messages = messages_from_dict(once.get("messages", [])) return conversation diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 6e7d8b700..8fefdbfff 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -30,7 +30,7 @@ from pilot.configs.model_config import ( LOGDIR, VECTOR_SEARCH_TOP_K, ) -from pilot.connections.mysql import MySQLOperator + from pilot.conversation import ( SeparatorStyle, conv_qa_prompt_template, @@ -39,9 +39,9 @@ from pilot.conversation import ( conversation_types, default_conversation, ) -from pilot.plugins import scan_plugins -from pilot.prompts.auto_mode_prompt import AutoModePrompt -from pilot.prompts.generator import PromptGenerator +from pilot.common.plugins import scan_plugins + +from pilot.prompts.generator import PluginPromptGenerator from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_patch import Chatbot as grChatbot from pilot.server.vectordb_qa import KnownLedgeBaseQA @@ -95,19 +95,14 @@ def get_simlar(q): def gen_sqlgen_conversation(dbname): - mo = MySQLOperator(**DB_SETTINGS) - message = "" - - schemas = mo.get_schema(dbname) + db_connect = CFG.local_db.get_session(dbname) + schemas = CFG.local_db.table_simple_info(db_connect) for s in schemas: message += s["schema_info"] + ";" return f"数据库{dbname}的Schema信息如下: {message}\n" -def get_database_list(): - mo = MySQLOperator(**DB_SETTINGS) - return mo.get_db_list() get_window_url_params = """ @@ -127,7 +122,6 @@ function() { def load_demo(url_params, request: gr.Request): logger.info(f"load_demo. ip: {request.client.host}. 
params: {url_params}") - # dbs = get_database_list() dropdown_update = gr.Dropdown.update(visible=True) if dbs: gr.Dropdown.update(choices=dbs) @@ -137,13 +131,15 @@ def load_demo(url_params, request: gr.Request): unique_id = uuid.uuid1() state.conv_id = str(unique_id) - return (state, - dropdown_update, - gr.Chatbot.update(visible=True), - gr.Textbox.update(visible=True), - gr.Button.update(visible=True), - gr.Row.update(visible=True), - gr.Accordion.update(visible=True)) + return ( + state, + dropdown_update, + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) def get_conv_log_filename(): @@ -203,30 +199,31 @@ def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene: elif mode == conversation_types["auto_execute_plugin"] and not db_selector: return ChatScene.ChatExecution else: - return ChatScene.ChatNormal + return ChatScene.ChatNormal -def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request): +def http_bot( + state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request +): logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}") start_tstamp = time.time() - scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector) + scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector) print(f"当前对话模式:{scene.value}") model_name = CFG.LLM_MODEL if ChatScene.ChatWithDb == scene: logger.info("基于DB对话走新的模式!") - chat_param ={ + chat_param = { "chat_session_id": state.conv_id, "db_name": db_selector, - "user_input": state.last_user_input + "user_input": state.last_user_input, } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) chat.call() - state.messages[-1][-1] = f"{chat.current_ai_response()}" + state.messages[-1][-1] = f"{chat.current_ai_response()}" yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 else: - dbname = db_selector # TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化 if state.skip_next: @@ -242,7 +239,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? 
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。 if db_selector: - new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query) + new_state.append_message( + new_state.roles[0], gen_sqlgen_conversation(dbname) + query + ) new_state.append_message(new_state.roles[1], None) else: new_state.append_message(new_state.roles[0], query) @@ -251,7 +250,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re new_state.conv_id = uuid.uuid4().hex state = new_state - prompt = state.get_prompt() skip_echo_len = len(prompt.replace("", " ")) + 1 if mode == conversation_types["default_knownledge"] and not db_selector: @@ -263,16 +261,24 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re skip_echo_len = len(prompt.replace("", " ")) + 1 if mode == conversation_types["custome"] and not db_selector: - persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb") + persist_dir = os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb" + ) print("向量数据库持久化地址: ", persist_dir) - knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"], - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) + knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG["sentence-transforms"], + vector_store_config={ + "vector_store_name": vector_store_name["vs_name"], + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + }, + ) query = state.messages[-2][1] docs = knowledge_embedding_client.similar_search(query, 1) context = [d.page_content for d in docs] prompt_template = PromptTemplate( template=conv_qa_prompt_template, - input_variables=["context", "question"] + input_variables=["context", "question"], ) result = prompt_template.format(context="\n".join(context), question=query) state.messages[-2][1] = result @@ -285,7 +291,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re "prompt": prompt, "temperature": float(temperature), "max_new_tokens": int(max_new_tokens), - "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2, + "stop": state.sep + if state.sep_style == SeparatorStyle.SINGLE + else state.sep2, } logger.info(f"Requert: \n{payload}") @@ -295,8 +303,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re try: # Stream output - response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, json=payload, stream=True, timeout=20) + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=20, + ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) @@ -309,12 +322,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) + ( - disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) return except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" - yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + yield (state, 
state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) return state.messages[-1][-1] = state.messages[-1][-1][:-1] @@ -405,10 +429,14 @@ def build_single_model_ui(): interactive=True, label="最大输出Token数", ) + + tabs = gr.Tabs() + with tabs: tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL") with tab_sql: + print("tab_sql in...") # TODO A selector to choose database with gr.Row(elem_id="db_selector"): db_selector = gr.Dropdown( @@ -423,8 +451,23 @@ def build_single_model_ui(): sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力") sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting) + tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN") + with tab_plugin: + print("tab_plugin in...") + with gr.Row(elem_id="plugin_selector"): + # TODO + plugin_selector = gr.Dropdown( + label="请选择插件", + choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""], + value="datadance-ddl-excutor", + interactive=True, + show_label=True, + ).style(container=False) + + tab_qa = gr.TabItem("知识问答", elem_id="QA") with tab_qa: + print("tab_qa in...") mode = gr.Radio( ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话" ) @@ -483,7 +526,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) @@ -573,7 +616,6 @@ def knowledge_embedding_store(vs_id, files): ) knowledge_embedding_client.knowledge_embedding() - logger.info("knowledge embedding success") return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb") diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 06fad00f2..3ff473f1e 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,7 +1,7 @@ from pilot.vector_store.chroma_store import ChromaStore -from pilot.vector_store.milvus_store import MilvusStore +# from pilot.vector_store.milvus_store import MilvusStore -connector = {"Chroma": ChromaStore, "Milvus": MilvusStore} +connector = {"Chroma": ChromaStore, "Milvus": None} class VectorStoreConnector: diff --git a/requirements.txt b/requirements.txt index 19d8ca34e..9595854bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,7 +54,7 @@ gTTS==2.3.1 langchain nltk python-dotenv==1.0.0 -pymilvus==2.2.1 +# pymilvus==2.2.1 vcrpy chromadb markdown2 From 9511cdb10a129db09416bb035d9cc16c1133c6df Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 30 May 2023 10:45:12 +0800 Subject: [PATCH 2/7] add plugin mode --- pilot/prompts/prompt_generator.py | 52 ------------- pilot/scene/chat_db/chat.py | 17 ++-- pilot/scene/chat_execution/chat.py | 111 ++++----------------------- pilot/scene/chat_execution/prompt.py | 2 +- pilot/server/webserver.py | 87 +++++++++++++++------ 5 files changed, 84 insertions(+), 185 deletions(-) delete mode 100644 pilot/prompts/prompt_generator.py diff --git a/pilot/prompts/prompt_generator.py b/pilot/prompts/prompt_generator.py deleted file mode 100644 index 1ec62d5c9..000000000 --- a/pilot/prompts/prompt_generator.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional - - -class PromptGenerator: - """ - generating custom prompt strings based on constraints; - Compatible with AutoGpt 
Plugin; - """ - - def __init__(self) -> None: - """ - Initialize the PromptGenerator object with empty lists of constraints, - commands, resources, and performance evaluations. - """ - self.constraints = [] - self.commands = [] - self.resources = [] - self.performance_evaluation = [] - self.goals = [] - self.command_registry = None - self.name = "Bob" - self.role = "AI" - self.response_format = None - - def add_command( - self, - command_label: str, - command_name: str, - args=None, - function: Optional[Callable] = None, - ) -> None: - """ - Add a command to the commands list with a label, name, and optional arguments. - GB-GPT and Auto-GPT plugin registration command. - Args: - command_label (str): The label of the command. - command_name (str): The name of the command. - args (dict, optional): A dictionary containing argument names and their - values. Defaults to None. - function (callable, optional): A callable function to be called when - the command is executed. Defaults to None. - """ - if args is None: - args = {} - command_args = {arg_key: arg_value for arg_key, arg_value in args.items()} - command = { - "label": command_label, - "name": command_name, - "args": command_args, - "function": function, - } - self.commands.append(command) diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py index 37382b7d2..745e9804d 100644 --- a/pilot/scene/chat_db/chat.py +++ b/pilot/scene/chat_db/chat.py @@ -42,6 +42,7 @@ from pilot.scene.chat_db.out_parser import DbChatOutputParser CFG = Config() + class ChatWithDb(BaseChat): chat_scene: str = ChatScene.ChatWithDb.value @@ -49,7 +50,7 @@ class ChatWithDb(BaseChat): def __init__(self, chat_session_id, db_name, user_input): """ """ - super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input) + super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input) if not db_name: raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!") self.db_name = db_name @@ -60,17 +61,16 @@ class ChatWithDb(BaseChat): def generate_input_values(self): input_values = { - "input": self.current_user_input, - "top_k": str(self.top_k), - "dialect": self.database.dialect, - "table_info": self.database.table_simple_info(self.db_connect) - } + "input": self.current_user_input, + "top_k": str(self.top_k), + "dialect": self.database.dialect, + "table_info": self.database.table_simple_info(self.db_connect) + } return input_values def do_with_prompt_response(self, prompt_response): return self.database.run(self.db_connect, prompt_response.sql) - # def call(self) -> str: # input_values = { # "input": self.current_user_input, @@ -176,9 +176,6 @@ class ChatWithDb(BaseChat): return ret - - - @property def chat_type(self) -> str: return ChatScene.ChatExecution.value diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index a5abadad0..210b2ad77 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -11,6 +11,8 @@ from pilot.configs.config import Config from pilot.commands.command import execute_command from pilot.prompts.generator import PluginPromptGenerator +from pilot.scene.chat_execution.prompt import chat_plugin_prompt + CFG = Config() class ChatWithPlugin(BaseChat): @@ -18,15 +20,19 @@ class ChatWithPlugin(BaseChat): plugins_prompt_generator:PluginPromptGenerator select_plugin: str = None - def __init__(self, chat_mode, chat_session_id, current_user_input, select_plugin:str=None): - super().__init__(chat_mode, chat_session_id, 
current_user_input) + def __init__(self, chat_session_id, user_input, plugin_selector:str=None): + super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input) self.plugins_prompt_generator = PluginPromptGenerator() - self.plugins_prompt_generator.command_registry = self.command_registry + self.plugins_prompt_generator.command_registry = CFG.command_registry # 加载插件中可用命令 - self.select_plugin = select_plugin + self.select_plugin = plugin_selector if self.select_plugin: for plugin in CFG.plugins: - if plugin. + if plugin._name == plugin_selector : + if not plugin.can_handle_post_prompt(): + continue + self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator) + else: for plugin in CFG.plugins: if not plugin.can_handle_post_prompt(): @@ -39,7 +45,7 @@ class ChatWithPlugin(BaseChat): def generate_input_values(self): input_values = { "input": self.current_user_input, - "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints), + "constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)), "commands_infos": self.plugins_prompt_generator.generate_commands_string() } return input_values @@ -48,101 +54,12 @@ class ChatWithPlugin(BaseChat): ## plugin command run return execute_command(str(prompt_response), self.plugins_prompt_generator) - - # def call(self): - # input_values = { - # "input": self.current_user_input, - # "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints), - # "commands_infos": self.__get_comnands_promp_info() - # } - # - # ### Chat sequence advance - # self.current_message.chat_order = len(self.history_message) + 1 - # self.current_message.add_user_message(self.current_user_input) - # self.current_message.start_date = datetime.datetime.now() - # # TODO - # self.current_message.tokens = 0 - # - # current_prompt = self.prompt_template.format(**input_values) - # - # ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 - # if self.history_message: - # ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 - # logger.info( - # f"There are already {len(self.history_message)} rounds of conversations!" 
- # ) - # - # self.current_message.add_system_message(current_prompt) - # - # payload = { - # "model": self.llm_model, - # "prompt": self.generate_llm_text(), - # "temperature": float(self.temperature), - # "max_new_tokens": int(self.max_new_tokens), - # "stop": self.prompt_template.sep, - # } - # logger.info(f"Requert: \n{payload}") - # ai_response_text = "" - # try: - # ### 走非流式的模型服务接口 - # - # response = requests.post( - # urljoin(CFG.MODEL_SERVER, "generate"), - # headers=headers, - # json=payload, - # timeout=120, - # ) - # ai_response_text = ( - # self.prompt_template.output_parser.parse_model_server_out(response) - # ) - # self.current_message.add_ai_message(ai_response_text) - # prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) - # - # - # ## plugin command run - # result = execute_command(prompt_define_response, self.plugins_prompt_generator) - # - # if hasattr(prompt_define_response, "thoughts"): - # if prompt_define_response.thoughts.get("speak"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("speak"), result - # ) - # ) - # elif prompt_define_response.thoughts.get("reasoning"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("reasoning"), result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts, result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response, result - # ) - # ) - # - # except Exception as e: - # print(traceback.format_exc()) - # logger.error("model response parase faild!" 
+ str(e)) - # self.current_message.add_view_message( - # f"""ERROR!{str(e)}\n {ai_response_text} """ - # ) - # ### 对话记录存储 - # self.memory.append(self.current_message) - def chat_show(self): super().chat_show() - def __list_to_prompt_str(list: List) -> str: - if not list: + def __list_to_prompt_str(self, list: List) -> str: + if list: separator = '\n' return separator.join(list) else: diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index e3469d7c2..44f564afe 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -52,7 +52,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False chat_plugin_prompt = PromptTemplate( template_scene=ChatScene.ChatExecution.value, - input_variables=["input", "table_info", "dialect", "top_k", "response"], + input_variables=["input", "constraints", "commands_infos", "response"], response_format=json.dumps(RESPONSE_FORMAT, indent=4), template_define=PROMPT_SCENE_DEFINE, template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE, diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 8fefdbfff..13bf38e5b 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -103,6 +103,11 @@ def gen_sqlgen_conversation(dbname): return f"数据库{dbname}的Schema信息如下: {message}\n" +def plugins_select_info(): + plugins_infos: dict = {} + for plugin in CFG.plugins: + plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name}) + return plugins_infos get_window_url_params = """ @@ -188,26 +193,27 @@ def post_process_code(code): return code -def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene: - if mode == conversation_types["default_knownledge"] and not db_selector: - return ChatScene.ChatKnowledge - elif mode == conversation_types["custome"] and not db_selector: - return ChatScene.ChatNewKnowledge - elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector: - return ChatScene.ChatWithDb - - elif mode == conversation_types["auto_execute_plugin"] and not db_selector: +def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene: + if "插件模式" == selected: return ChatScene.ChatExecution + elif "知识问答" == selected: + if mode == conversation_types["default_knownledge"]: + return ChatScene.ChatKnowledge + elif mode == conversation_types["custome"]: + return ChatScene.ChatNewKnowledge else: - return ChatScene.ChatNormal + if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector: + return ChatScene.ChatWithDb + + return ChatScene.ChatNormal def http_bot( - state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request + state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request ): - logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}") + logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}") start_tstamp = time.time() - scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector) + scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector) print(f"当前对话模式:{scene.value}") model_name = CFG.LLM_MODEL @@ -216,6 +222,17 @@ def http_bot( chat_param = { "chat_session_id": state.conv_id, "db_name": db_selector, + "current_user_input": state.last_user_input, + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + chat.call() + state.messages[-1][-1] = f"{chat.current_ai_response()}" + yield (state, state.to_gradio_chatbot()) + 
(enable_btn,) * 5 + elif ChatScene.ChatExecution == scene: + logger.info("插件模式对话走新的模式!") + chat_param = { + "chat_session_id": state.conv_id, + "plugin_selector": plugin_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) @@ -362,8 +379,8 @@ def http_bot( block_css = ( - code_highlight_css - + """ + code_highlight_css + + """ pre { white-space: pre-wrap; /* Since CSS 2.1 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ @@ -396,6 +413,11 @@ def change_tab(): autogpt = True +def change_func(xx): + print("123") + print(str(xx)) + + def build_single_model_ui(): notice_markdown = """ # DB-GPT @@ -430,11 +452,18 @@ def build_single_model_ui(): label="最大输出Token数", ) - tabs = gr.Tabs() + def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData + print(f"You selected {evt.value} at {evt.index} from {evt.target}") + return evt.value + + selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected") + tabs.select(on_select, None, selected) + with tabs: tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL") + tab_sql.select(on_select, None, None) with tab_sql: print("tab_sql in...") # TODO A selector to choose database @@ -452,18 +481,26 @@ def build_single_model_ui(): sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting) tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN") + # tab_plugin.select(change_func) with tab_plugin: print("tab_plugin in...") with gr.Row(elem_id="plugin_selector"): # TODO plugin_selector = gr.Dropdown( label="请选择插件", - choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""], - value="datadance-ddl-excutor", + choices=list(plugins_select_info().keys()), + value="", interactive=True, show_label=True, + type="value" ).style(container=False) + def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData + print(f"You selected {evt.value} at {evt.index} from {evt.target}") + return plugins_select_info().get(evt.value) + + plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected") + plugin_selector.select(plugin_change, None, plugin_selected) tab_qa = gr.TabItem("知识问答", elem_id="QA") with tab_qa: @@ -517,7 +554,7 @@ def build_single_model_ui(): btn_list = [regenerate_btn, clear_btn] regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot, - [state, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) @@ -526,7 +563,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) @@ -534,7 +571,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) vs_add.click( @@ -557,10 +594,10 @@ def build_single_model_ui(): def build_webdemo(): with gr.Blocks( - 
title="数据库智能助手", - # theme=gr.themes.Base(), - theme=gr.themes.Default(), - css=block_css, + title="数据库智能助手", + # theme=gr.themes.Base(), + theme=gr.themes.Default(), + css=block_css, ) as demo: url_params = gr.JSON(visible=False) ( From 5150cfcf55a20d39bdcbb7b4f95031eb0995bbc7 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 30 May 2023 17:20:37 +0800 Subject: [PATCH 3/7] add plugin mode --- pilot/commands/built_in/__init__.py | 0 pilot/commands/{ => built_in}/audio_text.py | 0 pilot/commands/{ => built_in}/image_gen.py | 0 pilot/commands/commands_load.py | 29 -- pilot/conversation.py | 10 +- pilot/language/lang_content_mapping.py | 11 +- pilot/out_parser/base.py | 39 ++- pilot/scene/base_chat.py | 19 +- pilot/scene/chat_knowledge/custom/__init__.py | 0 .../scene/chat_knowledge/default/__init__.py | 0 pilot/scene/chat_knowledge/url/__init__.py | 0 pilot/server/webserver.py | 250 ++++++++++-------- pilot/source_embedding/external/__init__.py | 0 pilot/source_embedding/knowledge_embedding.py | 7 + 14 files changed, 212 insertions(+), 153 deletions(-) create mode 100644 pilot/commands/built_in/__init__.py rename pilot/commands/{ => built_in}/audio_text.py (100%) rename pilot/commands/{ => built_in}/image_gen.py (100%) delete mode 100644 pilot/commands/commands_load.py create mode 100644 pilot/scene/chat_knowledge/custom/__init__.py create mode 100644 pilot/scene/chat_knowledge/default/__init__.py create mode 100644 pilot/scene/chat_knowledge/url/__init__.py create mode 100644 pilot/source_embedding/external/__init__.py diff --git a/pilot/commands/built_in/__init__.py b/pilot/commands/built_in/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/commands/audio_text.py b/pilot/commands/built_in/audio_text.py similarity index 100% rename from pilot/commands/audio_text.py rename to pilot/commands/built_in/audio_text.py diff --git a/pilot/commands/image_gen.py b/pilot/commands/built_in/image_gen.py similarity index 100% rename from pilot/commands/image_gen.py rename to pilot/commands/built_in/image_gen.py diff --git a/pilot/commands/commands_load.py b/pilot/commands/commands_load.py deleted file mode 100644 index a6fad3db2..000000000 --- a/pilot/commands/commands_load.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from pilot.configs.config import Config -from pilot.prompts.generator import PromptGenerator -from pilot.prompts.prompt import build_default_prompt_generator - - -class CommandsLoad: - """ - Load Plugins Commands Info , help build system prompt! 
- """ - - def __init__(self) -> None: - self.command_registry = None - - def getCommandInfos( - self, prompt_generator: Optional[PromptGenerator] = None - ) -> str: - cfg = Config() - if prompt_generator is None: - prompt_generator = build_default_prompt_generator() - for plugin in cfg.plugins: - if not plugin.can_handle_post_prompt(): - continue - prompt_generator = plugin.post_prompt(prompt_generator) - self.prompt_generator = prompt_generator - command_infos = "" - command_infos += f"\n\n{prompt_generator.commands()}" - return command_infos diff --git a/pilot/conversation.py b/pilot/conversation.py index d2f6565ca..0673b49c3 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -263,6 +263,14 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回 # """ default_conversation = conv_one_shot + +chat_mode_title = { + "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"), + "chat_use_plugin": get_lang_text("chat_use_plugin"), + "knowledge_qa": get_lang_text("knowledge_qa"), + +} + conversation_sql_mode = { "auto_execute_ai_response": get_lang_text("sql_generate_mode_direct"), "dont_execute_ai_response": get_lang_text("sql_generate_mode_none"), @@ -274,7 +282,7 @@ conversation_types = { "knowledge_qa_type_default_knowledge_base_dialogue" ), "custome": get_lang_text("knowledge_qa_type_add_knowledge_base_dialogue"), - "auto_execute_plugin": get_lang_text("dialogue_use_plugin"), + "url": get_lang_text("knowledge_qa_type_url_knowledge_dialogue"), } conv_templates = { diff --git a/pilot/language/lang_content_mapping.py b/pilot/language/lang_content_mapping.py index 5d165b51c..bcea7ed3c 100644 --- a/pilot/language/lang_content_mapping.py +++ b/pilot/language/lang_content_mapping.py @@ -14,17 +14,22 @@ lang_dicts = { "knowledge_qa_type_llm_native_dialogue": "LLM原生对话", "knowledge_qa_type_default_knowledge_base_dialogue": "默认知识库对话", "knowledge_qa_type_add_knowledge_base_dialogue": "新增知识库对话", - "dialogue_use_plugin": "对话使用插件", + "knowledge_qa_type_url_knowledge_dialogue": "URL网页知识对话", "create_knowledge_base": "新建知识库", "sql_schema_info": "数据库{}的Schema信息如下: {}\n", "current_dialogue_mode": "当前对话模式", "database_smart_assistant": "数据库智能助手", "sql_vs_setting": "自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力", "knowledge_qa": "知识问答", + "chat_use_plugin": "插件模式", + "dialogue_use_plugin": "对话使用插件", + "select_plugin": "选择插件", "configure_knowledge_base": "配置知识库", "new_klg_name": "新知识库名称", + "url_input_label": "输入网页地址", "add_as_new_klg": "添加为新知识库", "add_file_to_klg": "向知识库中添加文件", + "upload_file": "上传文件", "add_file": "添加文件", "upload_and_load_to_klg": "上传并加载到知识库", @@ -47,14 +52,18 @@ lang_dicts = { "knowledge_qa_type_llm_native_dialogue": "LLM native dialogue", "knowledge_qa_type_default_knowledge_base_dialogue": "Default documents", "knowledge_qa_type_add_knowledge_base_dialogue": "Added documents", + "knowledge_qa_type_url_knowledge_dialogue": "Chat with url", "dialogue_use_plugin": "Dialogue Extension", "create_knowledge_base": "Create Knowledge Base", "sql_schema_info": "the schema information of database {}: {}\n", "current_dialogue_mode": "Current dialogue mode", "database_smart_assistant": "Database smart assistant", "sql_vs_setting": "In the automatic execution mode, DB-GPT can have the ability to execute SQL, read data from the network, automatically store and learn", + "chat_use_plugin": "Plugin Mode", + "select_plugin": "Select Plugin", "knowledge_qa": "Documents QA", "configure_knowledge_base": "Configure Documents", + "url_input_label": "Please input url", "new_klg_name": "New 
document name", "add_as_new_klg": "Add as new documents", "add_file_to_klg": "Add file to documents", diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 36ca8eb9c..57f0a7f7e 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -18,11 +18,14 @@ import re from pydantic import BaseModel, Extra, Field, root_validator from pilot.configs.model_config import LOGDIR -from pilot.prompts.base import PromptValue +from pilot.configs.config import Config T = TypeVar("T") logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") +CFG = Config() + + class BaseOutputParser(ABC): """Class to parse the output of an LLM call. @@ -33,9 +36,39 @@ class BaseOutputParser(ABC): self.sep = sep self.is_stream_out = is_stream_out + def __post_process_code(code): + sep = "\n```" + if sep in code: + blocks = code.split(sep) + if len(blocks) % 2 == 1: + for i in range(1, len(blocks), 2): + blocks[i] = blocks[i].replace("\\_", "_") + code = sep.join(blocks) + return code + # TODO 后续和模型绑定 def _parse_model_stream_resp(self, response, sep: str): - pass + + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + + """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. + """ + if data["error_code"] == 0: + if "vicuna" in CFG.LLM_MODEL: + + output = data["text"].strip() + else: + output = data["text"].strip() + + output = self.__post_process_code(output) + yield output + else: + output = ( + data["text"] + f" (error_code: {data['error_code']})" + ) + yield output def _parse_model_nostream_resp(self, response, sep: str): text = response.text.strip() @@ -64,7 +97,7 @@ class BaseOutputParser(ABC): else: raise ValueError("Model server error!code=" + respObj_ex["error_code"]) - def parse_model_server_out(self, response) -> str: + def parse_model_server_out(self, response): """ parse the model server http response Args: diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 7a1c77781..a376759bc 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import datetime import traceback +import json from pydantic import BaseModel, Field, root_validator, validator, Extra from typing import ( Any, @@ -41,6 +42,7 @@ headers = {"User-Agent": "dbgpt Client"} CFG = Config() + class BaseChat(ABC): chat_scene: str = None llm_model: Any = None @@ -89,8 +91,7 @@ class BaseChat(ABC): def do_with_prompt_response(self, prompt_response): pass - - def call(self): + def call(self, show_fn, state): input_values = self.generate_input_values() ### Chat sequence advance @@ -164,6 +165,7 @@ class BaseChat(ABC): prompt_define_response, result ) ) + show_fn(state, self.current_ai_response()) else: response = requests.post( urljoin(CFG.MODEL_SERVER, "generate_stream"), @@ -171,9 +173,14 @@ class BaseChat(ABC): json=payload, timeout=120, ) - #TODO - + show_fn(state, "▌") + ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) + show_info ="" + for resp_text_trunck in ai_response_text: + show_info = resp_text_trunck + show_fn(state, resp_text_trunck + "▌") + self.current_message.add_ai_message(show_info) except Exception as e: print(traceback.format_exc()) @@ -181,9 +188,11 @@ class BaseChat(ABC): self.current_message.add_view_message( f"""ERROR!{str(e)}\n {ai_response_text} """ ) + show_fn(state, self.current_ai_response()) ### 对话记录存储 self.memory.append(self.current_message) + def generate_llm_text(self) 
-> str: text = self.prompt_template.template_define + self.prompt_template.sep ### 线处理历史信息 @@ -229,8 +238,6 @@ class BaseChat(ABC): return text - def chat_show(self): - pass # 暂时为了兼容前端 def current_ai_response(self) -> str: diff --git a/pilot/scene/chat_knowledge/custom/__init__.py b/pilot/scene/chat_knowledge/custom/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_knowledge/default/__init__.py b/pilot/scene/chat_knowledge/default/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_knowledge/url/__init__.py b/pilot/scene/chat_knowledge/url/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 9d04280d6..2f77277aa 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -37,6 +37,7 @@ from pilot.conversation import ( conv_templates, conversation_sql_mode, conversation_types, + chat_mode_title, default_conversation, ) from pilot.common.plugins import scan_plugins @@ -95,6 +96,11 @@ default_knowledge_base_dialogue = get_lang_text( add_knowledge_base_dialogue = get_lang_text( "knowledge_qa_type_add_knowledge_base_dialogue" ) + +url_knowledge_dialogue = get_lang_text( + "knowledge_qa_type_url_knowledge_dialogue" +) + knowledge_qa_type_list = [ llm_native_dialogue, default_knowledge_base_dialogue, @@ -115,7 +121,7 @@ def gen_sqlgen_conversation(dbname): db_connect = CFG.local_db.get_session(dbname) schemas = CFG.local_db.table_simple_info(db_connect) for s in schemas: - message += s["schema_info"] + ";" + message += s+ ";" return get_lang_text("sql_schema_info").format(dbname, message) @@ -211,9 +217,9 @@ def post_process_code(code): def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene: - if "插件模式" == selected: + if chat_mode_title['chat_use_plugin'] == selected: return ChatScene.ChatExecution - elif "知识问答" == selected: + elif chat_mode_title['knowledge_qa'] == selected: if mode == conversation_types["default_knownledge"]: return ChatScene.ChatKnowledge elif mode == conversation_types["custome"]: @@ -226,37 +232,50 @@ def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene: def http_bot( - state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request + state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request ): logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}") start_tstamp = time.time() - scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector) - print(f"当前对话模式:{scene.value}") + scene:ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector) + print(f"now chat scene:{scene.value}") model_name = CFG.LLM_MODEL + def chatbot_callback(state, message): + print(f"chatbot_callback:{message}") + state.messages[-1][-1] = f"{message}" + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + if ChatScene.ChatWithDb == scene: - logger.info("基于DB对话走新的模式!") + logger.info("chat with db mode use new architecture design!") chat_param = { "chat_session_id": state.conv_id, "db_name": db_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - chat.call() - state.messages[-1][-1] = f"{chat.current_ai_response()}" - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + chat.call(show_fn=chatbot_callback, state= state) + elif 
ChatScene.ChatExecution == scene: - logger.info("插件模式对话走新的模式!") + logger.info("plugin mode use new architecture design!") chat_param = { "chat_session_id": state.conv_id, "plugin_selector": plugin_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - chat.call() - state.messages[-1][-1] = f"{chat.current_ai_response()}" - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + chat.call(chatbot_callback, state) + # def generate_numbers(): + # for i in range(10): + # time.sleep(0.5) + # yield f"Message:{i}" + # + # def showMessage(message): + # return message + # + # for n in generate_numbers(): + # state.messages[-1][-1] = n + "▌" + # yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 else: dbname = db_selector @@ -284,30 +303,45 @@ def http_bot( new_state.conv_id = uuid.uuid4().hex state = new_state + else: + ### 后续对话 + query = state.messages[-2][1] + # 第一轮对话需要加入提示Prompt + if mode == conversation_types["custome"]: + template_name = "conv_one_shot" + new_state = conv_templates[template_name].copy() + # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? + # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 + if db_selector: + new_state.append_message( + new_state.roles[0], gen_sqlgen_conversation(dbname) + query + ) + new_state.append_message(new_state.roles[1], None) + else: + new_state.append_message(new_state.roles[0], query) + new_state.append_message(new_state.roles[1], None) + state = new_state prompt = state.get_prompt() skip_echo_len = len(prompt.replace("", " ")) + 1 if mode == conversation_types["default_knownledge"] and not db_selector: + vector_store_config = { + "vector_store_name": "default", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) query = state.messages[-2][1] - knqa = KnownLedgeBaseQA() - state.messages[-2][1] = knqa.get_similar_answer(query) - prompt = state.get_prompt() + docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) + prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state) state.messages[-2][1] = query skip_echo_len = len(prompt.replace("", " ")) + 1 if mode == conversation_types["custome"] and not db_selector: - persist_dir = os.path.join( - KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb" - ) - print("向量数据库持久化地址: ", persist_dir) - knowledge_embedding_client = KnowledgeEmbedding( - file_path="", - model_name=LLM_MODEL_CONFIG["sentence-transforms"], - vector_store_config={ - "vector_store_name": vector_store_name["vs_name"], - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, - }, - ) print("vector store name: ", vector_store_name["vs_name"]) vector_store_config = { "vector_store_name": vector_store_name["vs_name"], @@ -327,6 +361,27 @@ def http_bot( state.messages[-2][1] = query skip_echo_len = len(prompt.replace("", " ")) + 1 + if mode == conversation_types["url"] and url_input: + print("url: ", url_input) + vector_store_config = { + "vector_store_name": url_input, + "text_field": "content", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + knowledge_embedding_client = KnowledgeEmbedding( + file_path=url_input, + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) + + query = state.messages[-2][1] + docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) + 
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state) + + state.messages[-2][1] = query + skip_echo_len = len(prompt.replace("", " ")) + 1 + # Make requests payload = { "model": model_name, @@ -355,13 +410,24 @@ def http_bot( for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) + + """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. + """ if data["error_code"] == 0: - output = data["text"][skip_echo_len:].strip() + if "vicuna" in CFG.LLM_MODEL: + output = data["text"][skip_echo_len:].strip() + else: + output = data["text"].strip() + output = post_process_code(output) state.messages[-1][-1] = output + "▌" - yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + ) * 5 else: - output = data["text"] + f" (error_code: {data['error_code']})" + output = ( + data["text"] + f" (error_code: {data['error_code']})" + ) state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) + ( disable_btn, @@ -371,56 +437,7 @@ def http_bot( enable_btn, ) return - try: - # Stream output - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, - json=payload, - stream=True, - timeout=20, - ) - for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: - data = json.loads(chunk.decode()) - """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. - """ - if data["error_code"] == 0: - if "vicuna" in CFG.LLM_MODEL: - output = data["text"][skip_echo_len:].strip() - else: - output = data["text"].strip() - - output = post_process_code(output) - state.messages[-1][-1] = output + "▌" - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - ) * 5 - else: - output = ( - data["text"] + f" (error_code: {data['error_code']})" - ) - state.messages[-1][-1] = output - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - disable_btn, - disable_btn, - enable_btn, - enable_btn, - ) - return - - except requests.exceptions.RequestException as e: - state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - disable_btn, - disable_btn, - enable_btn, - enable_btn, - ) - return except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" yield (state, state.to_gradio_chatbot()) + ( @@ -432,29 +449,29 @@ def http_bot( ) return - state.messages[-1][-1] = state.messages[-1][-1][:-1] - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - # 记录运行日志 - finish_tstamp = time.time() - logger.info(f"{output}") + # 记录运行日志 + finish_tstamp = time.time() + logger.info(f"{output}") - with open(get_conv_log_filename(), "a") as fout: - data = { - "tstamp": round(finish_tstamp, 4), - "type": "chat", - "model": model_name, - "start": round(start_tstamp, 4), - "finish": round(start_tstamp, 4), - "state": state.dict(), - "ip": request.client.host, - } - fout.write(json.dumps(data) + "\n") + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(start_tstamp, 4), + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") block_css = ( - 
code_highlight_css - + """ + code_highlight_css + + """ pre { white-space: pre-wrap; /* Since CSS 2.1 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ @@ -477,15 +494,12 @@ def change_sql_mode(sql_mode): def change_mode(mode): - if mode in [default_knowledge_base_dialogue, llm_native_dialogue]: - return gr.update(visible=False) - else: + if mode in [add_knowledge_base_dialogue]: return gr.update(visible=True) + else: + return gr.update(visible=False) -def change_tab(): - autogpt = True - def build_single_model_ui(): notice_markdown = get_lang_text("db_gpt_introduction") @@ -548,15 +562,14 @@ def build_single_model_ui(): sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting")) sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting) - tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA") - tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN") + tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN") # tab_plugin.select(change_func) with tab_plugin: print("tab_plugin in...") with gr.Row(elem_id="plugin_selector"): # TODO plugin_selector = gr.Dropdown( - label="请选择插件", + label=get_lang_text("select_plugin"), choices=list(plugins_select_info().keys()), value="", interactive=True, @@ -578,6 +591,7 @@ def build_single_model_ui(): llm_native_dialogue, default_knowledge_base_dialogue, add_knowledge_base_dialogue, + url_knowledge_dialogue, ], show_label=False, value=llm_native_dialogue, @@ -586,6 +600,16 @@ def build_single_model_ui(): get_lang_text("configure_knowledge_base"), open=False ) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) + + url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True) + def show_url_input(evt:gr.SelectData): + if evt.value == url_knowledge_dialogue: + return gr.update(visible=True) + else: + return gr.update(visible=False) + mode.select(fn=show_url_input, inputs=None, outputs=url_input) + + with vs_setting: vs_name = gr.Textbox( label=get_lang_text("new_klg_name"), lines=1, interactive=True @@ -636,7 +660,7 @@ def build_single_model_ui(): btn_list = [regenerate_btn, clear_btn] regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) @@ -645,7 +669,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], [state, chatbot] + btn_list, ) @@ -653,7 +677,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], [state, chatbot] + btn_list, ) vs_add.click( @@ -760,8 +784,8 @@ if __name__ == "__main__": # 加载插件可执行命令 command_categories = [ - "pilot.commands.audio_text", - "pilot.commands.image_gen", + "pilot.commands.built_in.audio_text", + "pilot.commands.built_in.image_gen", ] # 排除禁用命令 
command_categories = [ diff --git a/pilot/source_embedding/external/__init__.py b/pilot/source_embedding/external/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 316667dee..8f411657d 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -11,6 +11,7 @@ from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding +from pilot.source_embedding.url_embedding import URLEmbedding from pilot.vector_store.connector import VectorStoreConnector CFG = Config() @@ -61,6 +62,12 @@ class KnowledgeEmbedding: model_name=self.model_name, vector_store_config=self.vector_store_config, ) + elif self.file_type == "url": + embedding = URLEmbedding( + file_path=self.file_path, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + ) return embedding From cd0469fdc4d7e1e0c94f34576726230f9508e708 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 30 May 2023 19:29:41 +0800 Subject: [PATCH 4/7] add plugin mode --- pilot/scene/base_chat.py | 214 +++++++++++++++++++++++++------------- pilot/server/webserver.py | 21 ++-- 2 files changed, 153 insertions(+), 82 deletions(-) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index a376759bc..1398a476a 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -1,3 +1,4 @@ +import time from abc import ABC, abstractmethod import datetime import traceback @@ -42,7 +43,6 @@ headers = {"User-Agent": "dbgpt Client"} CFG = Config() - class BaseChat(ABC): chat_scene: str = None llm_model: Any = None @@ -91,9 +91,8 @@ class BaseChat(ABC): def do_with_prompt_response(self, prompt_response): pass - def call(self, show_fn, state): - input_values = self.generate_input_values() - + def __call_base(self): + input_values = self.generate_input_values() ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) @@ -120,67 +119,40 @@ class BaseChat(ABC): "stop": self.prompt_template.sep, } logger.info(f"Requert: \n{payload}") + return payload + + def stream_call(self): + payload = self.__call_base() + logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - if not self.prompt_template.stream_out: - ### 走非流式的模型服务接口 - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate"), - headers=headers, - json=payload, - timeout=120, - ) + show_info = "" - ### output parse - ai_response_text = ( - self.prompt_template.output_parser.parse_model_server_out(response) - ) - self.current_message.add_ai_message(ai_response_text) - prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + # response = requests.post( + # urljoin(CFG.MODEL_SERVER, "generate_stream"), + # headers=headers, + # json=payload, + # timeout=120, + # ) + # + # ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) - result = self.do_with_prompt_response(prompt_define_response) + # for resp_text_trunck in ai_response_text: + # show_info = resp_text_trunck + # yield resp_text_trunck + "▌" + # - if hasattr(prompt_define_response, "thoughts"): - if prompt_define_response.thoughts.get("speak"): - 
self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response.thoughts.get("speak"), result - ) - ) - elif prompt_define_response.thoughts.get("reasoning"): - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response.thoughts.get("reasoning"), result - ) - ) - else: - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response.thoughts, result - ) - ) - else: - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response, result - ) - ) - show_fn(state, self.current_ai_response()) - else: - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, - json=payload, - timeout=120, - ) - show_fn(state, "▌") - ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) - show_info ="" - for resp_text_trunck in ai_response_text: - show_info = resp_text_trunck - show_fn(state, resp_text_trunck + "▌") + #### MOCK TEST + def mock_stream_out(): + for i in range(1, 11): + time.sleep(0.5) + yield f"Message:{i}" - self.current_message.add_ai_message(show_info) + for msg in mock_stream_out(): + show_info = msg + yield msg + "▌" + + self.current_message.add_ai_message(show_info) except Exception as e: print(traceback.format_exc()) @@ -188,10 +160,72 @@ class BaseChat(ABC): self.current_message.add_view_message( f"""ERROR!{str(e)}\n {ai_response_text} """ ) - show_fn(state, self.current_ai_response()) ### 对话记录存储 self.memory.append(self.current_message) + def nostream_call(self): + payload = self.__call_base() + logger.info(f"Requert: \n{payload}") + ai_response_text = "" + try: + ### 走非流式的模型服务接口 + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate"), + headers=headers, + json=payload, + timeout=120, + ) + + ### output parse + ai_response_text = ( + self.prompt_template.output_parser.parse_model_server_out(response) + ) + self.current_message.add_ai_message(ai_response_text) + prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + + result = self.do_with_prompt_response(prompt_define_response) + + if hasattr(prompt_define_response, "thoughts"): + if prompt_define_response.thoughts.get("speak"): + self.current_message.add_view_message( + self.prompt_template.output_parser.parse_view_response( + prompt_define_response.thoughts.get("speak"), result + ) + ) + elif prompt_define_response.thoughts.get("reasoning"): + self.current_message.add_view_message( + self.prompt_template.output_parser.parse_view_response( + prompt_define_response.thoughts.get("reasoning"), result + ) + ) + else: + self.current_message.add_view_message( + self.prompt_template.output_parser.parse_view_response( + prompt_define_response.thoughts, result + ) + ) + else: + self.current_message.add_view_message( + self.prompt_template.output_parser.parse_view_response( + prompt_define_response, result + ) + ) + + except Exception as e: + print(traceback.format_exc()) + logger.error("model response parase faild!" 
+ str(e)) + self.current_message.add_view_message( + f"""ERROR!{str(e)}\n {ai_response_text} """ + ) + ### 对话记录存储 + self.memory.append(self.current_message) + return self.current_ai_response() + + def call(self): + if self.prompt_template.stream_out: + yield self.stream_call() + else: + return self.nostream_call() def generate_llm_text(self) -> str: text = self.prompt_template.template_define + self.prompt_template.sep @@ -201,20 +235,20 @@ class BaseChat(ABC): for first_message in self.history_message[0].messages: if not isinstance(first_message, ViewMessage): text += ( - first_message.type - + ":" - + first_message.content - + self.prompt_template.sep + first_message.type + + ":" + + first_message.content + + self.prompt_template.sep ) index = self.chat_retention_rounds - 1 for last_message in self.history_message[-index:].messages: if not isinstance(last_message, ViewMessage): text += ( - last_message.type - + ":" - + last_message.content - + self.prompt_template.sep + last_message.type + + ":" + + last_message.content + + self.prompt_template.sep ) else: @@ -223,22 +257,20 @@ class BaseChat(ABC): for message in conversation.messages: if not isinstance(message, ViewMessage): text += ( - message.type - + ":" - + message.content - + self.prompt_template.sep + message.type + + ":" + + message.content + + self.prompt_template.sep ) ### current conversation for now_message in self.current_message.messages: text += ( - now_message.type + ":" + now_message.content + self.prompt_template.sep + now_message.type + ":" + now_message.content + self.prompt_template.sep ) return text - - # 暂时为了兼容前端 def current_ai_response(self) -> str: for message in self.current_message.messages: @@ -265,3 +297,35 @@ class BaseChat(ABC): """ pass + + +if __name__ == "__main__": + # + # def call_back(t, m): + # print(t) + # print(m) + # + # def my_fn(call_fn, xx): + # call_fn(1, xx) + # + # + # my_fn(call_back, "1231") + + def my_generator(): + while True: + value = yield + print('Received value:', value) + if value == 'stop': + return + + + # 创建生成器对象 + gen = my_generator() + + # 启动生成器 + next(gen) + + # 发送数据到生成器 + gen.send('Hello') + gen.send('World') + gen.send('stop') diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 2f77277aa..2e8f61016 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -230,6 +230,11 @@ def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene: return ChatScene.ChatNormal +def chatbot_callback(state, message): + print(f"chatbot_callback:{message}") + state.messages[-1][-1] = f"{message}" + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + def http_bot( state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request @@ -240,11 +245,6 @@ def http_bot( print(f"now chat scene:{scene.value}") model_name = CFG.LLM_MODEL - def chatbot_callback(state, message): - print(f"chatbot_callback:{message}") - state.messages[-1][-1] = f"{message}" - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - if ChatScene.ChatWithDb == scene: logger.info("chat with db mode use new architecture design!") chat_param = { @@ -253,7 +253,10 @@ def http_bot( "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - chat.call(show_fn=chatbot_callback, state= state) + chat.call() + + state.messages[-1][-1] = chat.current_ai_response() + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 elif ChatScene.ChatExecution == 
scene:
         logger.info("plugin mode use new architecture design!")
@@ -263,7 +266,11 @@ def http_bot(
             "user_input": state.last_user_input,
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
-        chat.call(chatbot_callback, state)
+        stream_generate = chat.stream_call()
+
+        for msg in stream_generate:
+            state.messages[-1][-1] = msg
+            yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
         # def generate_numbers():
         #     for i in range(10):

From 06bc4452d464277033f09f87d426b139c4294288 Mon Sep 17 00:00:00 2001
From: yhjun1026 <460342015@qq.com>
Date: Wed, 31 May 2023 15:59:50 +0800
Subject: [PATCH 5/7] Implemented a new multi-scenario dialogue architecture

---
 pilot/common/sql_database.py                  |   1 +
 pilot/conversation.py                         |  28 +-
 pilot/model/proxy_llm.py                      |  52 ++-
 pilot/out_parser/base.py                      |   5 +-
 pilot/prompts/prompt_new.py                   |  56 +--
 pilot/scene/base.py                           |   4 +-
 pilot/scene/base_chat.py                      |  99 ++----
 pilot/scene/chat_db/auto_execute/__init__.py  |   0
 pilot/scene/chat_db/auto_execute/chat.py      |  57 +++
 .../chat_db/{ => auto_execute}/out_parser.py  |   4 +-
 .../chat_db/{ => auto_execute}/prompt.py      |  27 +-
 pilot/scene/chat_db/chat.py                   | 240 -------------
 .../scene/chat_db/professional_qa/__init__.py |   0
 pilot/scene/chat_db/professional_qa/chat.py   |  56 +++
 .../chat_db/professional_qa/out_parser.py     |  22 ++
 pilot/scene/chat_db/professional_qa/prompt.py |  48 +++
 pilot/scene/chat_execution/chat.py            |  11 +-
 pilot/scene/chat_execution/out_parser.py      |   4 +-
 pilot/scene/chat_execution/prompt.py          |   5 +-
 .../chat_execution/prompt_with_command.py     |  65 ----
 pilot/scene/chat_factory.py                   |  12 +-
 pilot/scene/chat_knowledge/custom/chat.py     |  69 ++++
 .../scene/chat_knowledge/custom/out_parser.py |  22 ++
 pilot/scene/chat_knowledge/custom/prompt.py   |  43 +++
 pilot/scene/chat_knowledge/default/chat.py    |  66 ++++
 .../chat_knowledge/default/out_parser.py      |  22 ++
 pilot/scene/chat_knowledge/default/prompt.py  |  43 +++
 pilot/scene/chat_knowledge/url/chat.py        |  71 ++++
 pilot/scene/chat_knowledge/url/out_parser.py  |  22 ++
 pilot/scene/chat_knowledge/url/prompt.py      |  43 +++
 pilot/scene/chat_normal/chat.py               |  43 +++
 pilot/scene/chat_normal/out_parser.py         |  22 ++
 pilot/scene/chat_normal/prompt.py             |  50 +--
 pilot/server/webserver.py                     | 335 ++++++------------
 requirements.txt                              |   1 +
 35 files changed, 905 insertions(+), 743 deletions(-)
 create mode 100644 pilot/scene/chat_db/auto_execute/__init__.py
 create mode 100644 pilot/scene/chat_db/auto_execute/chat.py
 rename pilot/scene/chat_db/{ => auto_execute}/out_parser.py (90%)
 rename pilot/scene/chat_db/{ => auto_execute}/prompt.py (63%)
 delete mode 100644 pilot/scene/chat_db/chat.py
 create mode 100644 pilot/scene/chat_db/professional_qa/__init__.py
 create mode 100644 pilot/scene/chat_db/professional_qa/chat.py
 create mode 100644 pilot/scene/chat_db/professional_qa/out_parser.py
 create mode 100644 pilot/scene/chat_db/professional_qa/prompt.py
 delete mode 100644 pilot/scene/chat_execution/prompt_with_command.py
 create mode 100644 pilot/scene/chat_knowledge/custom/chat.py
 create mode 100644 pilot/scene/chat_knowledge/custom/out_parser.py
 create mode 100644 pilot/scene/chat_knowledge/custom/prompt.py
 create mode 100644 pilot/scene/chat_knowledge/default/chat.py
 create mode 100644 pilot/scene/chat_knowledge/default/out_parser.py
 create mode 100644 pilot/scene/chat_knowledge/default/prompt.py
 create mode 100644 pilot/scene/chat_knowledge/url/chat.py
 create mode 100644 pilot/scene/chat_knowledge/url/out_parser.py
 create mode 100644 pilot/scene/chat_knowledge/url/prompt.py
 create mode 100644 pilot/scene/chat_normal/chat.py
 create mode 100644 pilot/scene/chat_normal/out_parser.py
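Before the per-file hunks below, it may help to see the shape of the dispatch this patch introduces. The following is a toy Python sketch only: the class and registry names are stand-ins, and the real factory lives in pilot/scene/chat_factory.py, whose internals are not shown in this patch.

    from enum import Enum

    class ChatScene(Enum):
        # Values mirror the enum added in pilot/scene/base.py below.
        ChatWithDbExecute = "chat_with_db_execute"
        ChatWithDbQA = "chat_with_db_qa"
        ChatExecution = "chat_execution"

    class BaseChat:
        def __init__(self, **chat_param):
            self.chat_param = chat_param

    class ChatWithDbAutoExecute(BaseChat): pass
    class ChatWithDbQA(BaseChat): pass
    class ChatWithPlugin(BaseChat): pass

    # One chat class per scene; the web layer only picks a scene value and
    # forwards user input, mirroring CHAT_FACTORY.get_implementation().
    REGISTRY = {
        ChatScene.ChatWithDbExecute.value: ChatWithDbAutoExecute,
        ChatScene.ChatWithDbQA.value: ChatWithDbQA,
        ChatScene.ChatExecution.value: ChatWithPlugin,
    }

    def get_implementation(scene_value, **chat_param):
        return REGISTRY[scene_value](**chat_param)

    chat = get_implementation("chat_execution", chat_session_id="demo", user_input="hi")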
diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py
index 2b8d6fe4b..c3ac5bdc6 100644
--- a/pilot/common/sql_database.py
+++ b/pilot/common/sql_database.py
@@ -277,6 +277,7 @@ class Database:
 
     def run(self, session, command: str, fetch: str = "all") -> List:
         """Execute a SQL command and return a string representing the results."""
+        print("sql run:" + command)
         cursor = session.execute(text(command))
         if cursor.returns_rows:
             if fetch == "all":
diff --git a/pilot/conversation.py b/pilot/conversation.py
index ee07adf33..8a758dd51 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -105,18 +105,14 @@ class Conversation:
         }
 
 
-def gen_sqlgen_conversation(dbname):
-    from pilot.connections.mysql import MySQLOperator
-
-    mo = MySQLOperator(**(DB_SETTINGS))
-
-    message = ""
-
-    schemas = mo.get_schema(dbname)
-    for s in schemas:
-        message += s["schema_info"] + ";"
-    return f"Database {dbname} Schema information as follows: {message}\n"
-
+conv_default = Conversation(
+    system = None,
+    roles=("human", "ai"),
+    messages= (),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
 
 conv_one_shot = Conversation(
     system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
@@ -261,7 +257,7 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
 # question:
 # {question}
 # """
-default_conversation = conv_one_shot
+default_conversation = conv_default
 
 
 chat_mode_title = {
@@ -289,8 +285,4 @@ conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1,
     "auto_dbgpt_one_shot": auto_dbgpt_one_shot,
-}
-
-if __name__ == "__main__":
-    message = gen_sqlgen_conversation("dbgpt")
-    print(message)
+}
\ No newline at end of file
diff --git a/pilot/model/proxy_llm.py b/pilot/model/proxy_llm.py
index 3242603d3..92887cfc6 100644
--- a/pilot/model/proxy_llm.py
+++ b/pilot/model/proxy_llm.py
@@ -21,22 +21,47 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
     }
 
     messages = prompt.split(stop)
-    # Add history conversation
-    for i in range(1, len(messages) - 2, 2):
-        history.append(
-            {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
-        )
-        history.append(
-            {
-                "role": "system",
-                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
-            }
-        )
+    for message in messages:
+        if len(message) <= 0:
+            continue
+        if "human:" in message:
+            history.append(
+                {"role": "user", "content": message.split("human:")[1]},
+            )
+        elif "system:" in message:
+            history.append(
+                {
+                    "role": "system",
+                    "content": message.split("system:")[1],
+                }
+            )
+        elif "ai:" in message:
+            history.append(
+                {
+                    "role": "ai",
+                    "content": message.split("ai:")[1],
+                }
+            )
+        else:
+            history.append(
+                {
+                    "role": "system",
+                    "content": message,
+                }
+            )
+
+    # Move the most recent user message to the end of the history
+    temp_his = history[::-1]
+    last_user_input = None
+    for m in temp_his:
+        if m["role"] == "user":
+            last_user_input = m
+            break
+    if last_user_input:
+        history.remove(last_user_input)
+        history.append(last_user_input)
 
-    # Add user query
-    query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append({"role": "user", "content": query})
     payloads = {
         "model": "gpt-3.5-turbo",  # just for test, remove this later
         "messages": history,
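The proxy_llm hunk above rebuilds an OpenAI-style message list from the sep-joined prompt, then moves the most recent user message to the end so the proxied model answers it last. A standalone sketch of the same idea; this is a simplification under assumptions (one "role:content" pair per segment), not the exact implementation. Note that the hunk keeps the role string "ai", which the OpenAI chat endpoint would likely reject; the sketch maps it to "assistant" instead.

    from typing import Dict, List

    def rebuild_history(prompt: str, sep: str = "###") -> List[Dict[str, str]]:
        history = []
        for segment in prompt.split(sep):
            segment = segment.strip()
            if not segment:
                continue
            # Each "role:content" segment becomes one chat message.
            role, _, content = segment.partition(":")
            if role == "human":
                history.append({"role": "user", "content": content})
            elif role == "ai":
                history.append({"role": "assistant", "content": content})
            else:
                history.append({"role": "system", "content": content})
        # Move the most recent user message to the end of the list.
        last_user = next((m for m in reversed(history) if m["role"] == "user"), None)
        if last_user is not None:
            history.remove(last_user)
            history.append(last_user)
        return history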
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 57f0a7f7e..fee3eda37 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -36,7 +36,7 @@ class BaseOutputParser(ABC):
         self.sep = sep
         self.is_stream_out = is_stream_out
 
-    def __post_process_code(code):
+    def __post_process_code(self, code):
         sep = "\n```"
         if sep in code:
             blocks = code.split(sep)
@@ -92,7 +92,7 @@ class BaseOutputParser(ABC):
             ai_response = ai_response.replace("\n", "")
             ai_response = ai_response.replace("\_", "_")
             ai_response = ai_response.replace("\*", "*")
-            print("un_stream clear response:{}", ai_response)
+            print("un_stream ai response:", ai_response)
             return ai_response
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
@@ -140,6 +140,7 @@ class BaseOutputParser(ABC):
             cleaned_output = m.group(0)
         else:
             raise ValueError("model server out does not follow the prompt!")
+        cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '')
         return cleaned_output
 
     def parse_view_response(self, ai_text) -> str:
diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py
index 389b1a33e..888b6f81e 100644
--- a/pilot/prompts/prompt_new.py
+++ b/pilot/prompts/prompt_new.py
@@ -31,15 +31,15 @@ DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
 class PromptTemplate(BaseModel, ABC):
     input_variables: List[str]
     """A list of the names of the variables the prompt template expects."""
-    template_scene: str
+    template_scene: Optional[str]
 
-    template_define: str
+    template_define: Optional[str]
     """this template define"""
-    template: str
+    template: Optional[str]
     """The prompt template."""
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
-    response_format: str
+    response_format: Optional[str]
     """default use stream out"""
     stream_out: bool = True
     """"""
@@ -57,52 +57,12 @@ class PromptTemplate(BaseModel, ABC):
         """Return the prompt type key."""
         return "prompt"
 
-    def _generate_command_string(self, command: Dict[str, Any]) -> str:
-        """
-        Generate a formatted string representation of a command.
-
-        Args:
-            command (dict): A dictionary containing command information.
-
-        Returns:
-            str: The formatted command string.
-        """
-        args_string = ", ".join(
-            f'"{key}": "{value}"' for key, value in command["args"].items()
-        )
-        return f'{command["label"]}: "{command["name"]}", args: {args_string}'
-
-    def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
-        """
-        Generate a numbered list from given items based on the item_type.
-
-        Args:
-            items (list): A list of items to be numbered.
-            item_type (str, optional): The type of items in the list.
-                Defaults to 'list'.
-
-        Returns:
-            str: The formatted numbered list.
-        """
-        if item_type == "command":
-            command_strings = []
-            if self.command_registry:
-                command_strings += [
-                    str(item)
-                    for item in self.command_registry.commands.values()
-                    if item.enabled
-                ]
-            # terminate command is added manually
-            command_strings += [self._generate_command_string(item) for item in items]
-            return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
-        else:
-            return "\n".join(f"{i+1}. 
{item}" for i, item in enumerate(items)) - def format(self, **kwargs: Any) -> str: """Format the prompt with the inputs.""" - - kwargs["response"] = json.dumps(self.response_format, indent=4) - return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + if self.template: + if self.response_format: + kwargs["response"] = json.dumps(self.response_format, indent=4) + return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) def add_goals(self, goal: str) -> None: self.goals.append(goal) diff --git a/pilot/scene/base.py b/pilot/scene/base.py index 9fcc6fb31..21f605fed 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -2,8 +2,10 @@ from enum import Enum class ChatScene(Enum): - ChatWithDb = "chat_with_db" + ChatWithDbExecute = "chat_with_db_execute" + ChatWithDbQA = "chat_with_db_qa" ChatExecution = "chat_execution" ChatKnowledge = "chat_default_knowledge" ChatNewKnowledge = "chat_new_knowledge" + ChatUrlKnowledge = "chat_url_knowledge" ChatNormal = "chat_normal" diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 1398a476a..798e071e3 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -56,7 +56,7 @@ class BaseChat(ABC): arbitrary_types_allowed = True - def __init__(self, chat_mode, chat_session_id, current_user_input): + def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input): self.chat_session_id = chat_session_id self.chat_mode = chat_mode self.current_user_input: str = current_user_input @@ -64,12 +64,12 @@ class BaseChat(ABC): ### TODO self.memory = FileHistoryMemory(chat_session_id) ### load prompt template - self.prompt_template: PromptTemplate = CFG.prompt_templates[ - self.chat_mode.value - ] + self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value] self.history_message: List[OnceConversation] = [] self.current_message: OnceConversation = OnceConversation() self.current_tokens_used: int = 0 + self.temperature = temperature + self.max_new_tokens = max_new_tokens ### load chat_session_id's chat historys self._load_history(self.chat_session_id) @@ -92,15 +92,17 @@ class BaseChat(ABC): pass def __call_base(self): - input_values = self.generate_input_values() + input_values = self.generate_input_values() ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) self.current_message.start_date = datetime.datetime.now() # TODO self.current_message.tokens = 0 + current_prompt = None - current_prompt = self.prompt_template.format(**input_values) + if self.prompt_template.template: + current_prompt = self.prompt_template.format(**input_values) ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 if self.history_message: @@ -108,8 +110,8 @@ class BaseChat(ABC): logger.info( f"There are already {len(self.history_message)} rounds of conversations!" 
) - - self.current_message.add_system_message(current_prompt) + if current_prompt: + self.current_message.add_system_message(current_prompt) payload = { "model": self.llm_model, @@ -118,7 +120,6 @@ class BaseChat(ABC): "max_new_tokens": int(self.max_new_tokens), "stop": self.prompt_template.sep, } - logger.info(f"Requert: \n{payload}") return payload def stream_call(self): @@ -127,30 +128,18 @@ class BaseChat(ABC): ai_response_text = "" try: show_info = "" + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + timeout=120, + ) - # response = requests.post( - # urljoin(CFG.MODEL_SERVER, "generate_stream"), - # headers=headers, - # json=payload, - # timeout=120, - # ) - # - # ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) + ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) - # for resp_text_trunck in ai_response_text: - # show_info = resp_text_trunck - # yield resp_text_trunck + "▌" - # - - #### MOCK TEST - def mock_stream_out(): - for i in range(1, 11): - time.sleep(0.5) - yield f"Message:{i}" - - for msg in mock_stream_out(): - show_info = msg - yield msg + "▌" + for resp_text_trunck in ai_response_text: + show_info = resp_text_trunck + yield resp_text_trunck + "▌" self.current_message.add_ai_message(show_info) @@ -186,13 +175,13 @@ class BaseChat(ABC): result = self.do_with_prompt_response(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): - if prompt_define_response.thoughts.get("speak"): + if hasattr(prompt_define_response.thoughts, "speak"): self.current_message.add_view_message( self.prompt_template.output_parser.parse_view_response( prompt_define_response.thoughts.get("speak"), result ) ) - elif prompt_define_response.thoughts.get("reasoning"): + elif hasattr(prompt_define_response.thoughts, "reasoning"): self.current_message.add_view_message( self.prompt_template.output_parser.parse_view_response( prompt_define_response.thoughts.get("reasoning"), result @@ -223,15 +212,18 @@ class BaseChat(ABC): def call(self): if self.prompt_template.stream_out: - yield self.stream_call() + yield self.stream_call() else: return self.nostream_call() def generate_llm_text(self) -> str: - text = self.prompt_template.template_define + self.prompt_template.sep - ### 线处理历史信息 + text = "" + if self.prompt_template.template_define: + text = self.prompt_template.template_define + self.prompt_template.sep + + ### 处理历史信息 if len(self.history_message) > self.chat_retention_rounds: - ### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 + ### 使用历史信息的第一轮和最后n轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 for first_message in self.history_message[0].messages: if not isinstance(first_message, ViewMessage): text += ( @@ -262,8 +254,8 @@ class BaseChat(ABC): + message.content + self.prompt_template.sep ) - ### current conversation + for now_message in self.current_message.messages: text += ( now_message.type + ":" + now_message.content + self.prompt_template.sep @@ -298,34 +290,3 @@ class BaseChat(ABC): """ pass - -if __name__ == "__main__": - # - # def call_back(t, m): - # print(t) - # print(m) - # - # def my_fn(call_fn, xx): - # call_fn(1, xx) - # - # - # my_fn(call_back, "1231") - - def my_generator(): - while True: - value = yield - print('Received value:', value) - if value == 'stop': - return - - - # 创建生成器对象 - gen = my_generator() - - # 启动生成器 - next(gen) - - # 发送数据到生成器 - gen.send('Hello') - gen.send('World') - gen.send('stop') diff --git 
a/pilot/scene/chat_db/auto_execute/__init__.py b/pilot/scene/chat_db/auto_execute/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
new file mode 100644
index 000000000..2b8918fde
--- /dev/null
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -0,0 +1,57 @@
+import json
+
+from pilot.scene.base_message import (
+    HumanMessage,
+    ViewMessage,
+)
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.common.sql_database import Database
+from pilot.configs.config import Config
+from pilot.common.markdown_text import (
+    generate_htm_table,
+)
+from pilot.scene.chat_db.auto_execute.prompt import prompt
+
+CFG = Config()
+
+
+class ChatWithDbAutoExecute(BaseChat):
+    chat_scene: str = ChatScene.ChatWithDbExecute.value
+
+    """Number of results to return from the query"""
+
+    def __init__(self, temperature, max_new_tokens, chat_session_id, db_name, user_input):
+        """ """
+        super().__init__(temperature=temperature,
+                         max_new_tokens=max_new_tokens,
+                         chat_mode=ChatScene.ChatWithDbExecute,
+                         chat_session_id=chat_session_id,
+                         current_user_input=user_input)
+        if not db_name:
+            raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should choose a db!")
+        self.db_name = db_name
+        self.database = CFG.local_db
+        # Prepare DB info (get a connection session for the chosen database)
+        self.db_connect = self.database.get_session(self.db_name)
+        self.top_k: int = 5
+
+    def generate_input_values(self):
+        input_values = {
+            "input": self.current_user_input,
+            "top_k": str(self.top_k),
+            "dialect": self.database.dialect,
+            "table_info": self.database.table_simple_info(self.db_connect)
+        }
+        return input_values
+
+    def do_with_prompt_response(self, prompt_response):
+        return self.database.run(self.db_connect, prompt_response.sql)
+
+
+
+if __name__ == "__main__":
+    ss = "{\n    \"thoughts\": \"to get the user's city, we need to join the users table with the tran_order table using the user_name column. we also need to filter the results to only show orders for user test1.\",\n    \"sql\": \"select o.order_id, o.product_name, u.city from tran_order o join users u on o.user_name = u.user_name where o.user_name = 'test1' limit 5\"\n}"
+    ss.strip().replace('\n', '').replace('\\n', '').replace('', '').replace(' ', '').replace('\\', '').replace('\\', '')
+    print(ss)
+    json.loads(ss)
\ No newline at end of file
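The __main__ block above is probing exactly the input shape DbChatOutputParser has to survive: a JSON payload wrapped in stray newlines and escape sequences. A minimal standalone sketch of that clean-then-parse step follows; it is an approximation for illustration, not the parser's exact code.

    import json
    from typing import NamedTuple

    class SqlAction(NamedTuple):
        sql: str
        thoughts: str

    def parse_sql_action(model_out_text: str) -> SqlAction:
        cleaned = model_out_text.strip()
        # Strip markdown code fences the model may wrap around the JSON.
        for token in ("```json", "```"):
            cleaned = cleaned.replace(token, "")
        # Drop literal and escaped newlines so json.loads sees a single line.
        cleaned = cleaned.strip().replace("\n", "").replace("\\n", "")
        response = json.loads(cleaned)
        return SqlAction(sql=response["sql"], thoughts=response["thoughts"])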
diff --git a/pilot/scene/chat_db/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py
similarity index 90%
rename from pilot/scene/chat_db/out_parser.py
rename to pilot/scene/chat_db/auto_execute/out_parser.py
index 307aff680..cb059feb8 100644
--- a/pilot/scene/chat_db/out_parser.py
+++ b/pilot/scene/chat_db/auto_execute/out_parser.py
@@ -22,7 +22,9 @@ class DbChatOutputParser(BaseOutputParser):
 
     def parse_prompt_response(self, model_out_text):
-        response = json.loads(super().parse_prompt_response(model_out_text))
+        clean_str = super().parse_prompt_response(model_out_text)
+        print("clean prompt response:", clean_str)
+        response = json.loads(clean_str)
         sql, thoughts = response["sql"], response["thoughts"]
         return SqlAction(sql, thoughts)
 
diff --git a/pilot/scene/chat_db/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
similarity index 63%
rename from pilot/scene/chat_db/prompt.py
rename to pilot/scene/chat_db/auto_execute/prompt.py
index aeaf994c0..9a381345f 100644
--- a/pilot/scene/chat_db/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -1,36 +1,30 @@
 import json
+import importlib
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
-from pilot.scene.chat_db.out_parser import DbChatOutputParser, SqlAction
+from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
 from pilot.common.schema import SeparatorStyle
 
 CFG = Config()
 
 PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
 
-PROMPT_SUFFIX = """Only use the following tables:
-{table_info}
-
-Question: {input}
-
-"""
 
 _DEFAULT_TEMPLATE = """
 You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
 Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
 You can order the results by a relevant column to return the most interesting examples in the database.
 Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
+If a given table is not relevant to the question, do not force it into the query.
 Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
 
 """
 
-_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
-Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
-Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
Wrap each column name in backticks (`) to denote them as delimited identifiers. -Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. -Pay attention to use CURDATE() function to get the current date, if the question involves "today". +PROMPT_SUFFIX = """Only use the following tables: +{table_info} +Question: {input} """ @@ -49,17 +43,16 @@ RESPONSE_FORMAT = { } RESPONSE_FORMAT_SIMPLE = { - "thoughts": "thoughts summary to say to user", + "thoughts": "thoughts summary to say to user", "sql": "SQL Query to run", } - PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_NEED_NEED_STREAM_OUT = False -chat_db_prompt = PromptTemplate( - template_scene=ChatScene.ChatWithDb.value, +prompt = PromptTemplate( + template_scene=ChatScene.ChatWithDbExecute.value, input_variables=["input", "table_info", "dialect", "top_k", "response"], response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4), template_define=PROMPT_SCENE_DEFINE, @@ -69,5 +62,5 @@ chat_db_prompt = PromptTemplate( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT ), ) +CFG.prompt_templates.update({prompt.template_scene: prompt}) -CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt}) diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py deleted file mode 100644 index 745e9804d..000000000 --- a/pilot/scene/chat_db/chat.py +++ /dev/null @@ -1,240 +0,0 @@ -import requests -import datetime -import threading -import json -import traceback -from urllib.parse import urljoin -from sqlalchemy import ( - MetaData, - Table, - create_engine, - inspect, - select, - text, -) -from typing import Any, Iterable, List, Optional - -from pilot.scene.base_message import ( - BaseMessage, - SystemMessage, - HumanMessage, - AIMessage, - ViewMessage, -) -from pilot.scene.base_chat import BaseChat, logger, headers -from pilot.scene.base import ChatScene -from pilot.common.sql_database import Database -from pilot.configs.config import Config -from pilot.scene.chat_db.out_parser import SqlAction -from pilot.configs.model_config import LOGDIR, DATASETS_DIR -from pilot.utils import ( - build_logger, - server_error_msg, -) -from pilot.common.markdown_text import ( - generate_markdown_table, - generate_htm_table, - datas_to_table_html, -) -from pilot.scene.chat_db.prompt import chat_db_prompt -from pilot.out_parser.base import BaseOutputParser -from pilot.scene.chat_db.out_parser import DbChatOutputParser - -CFG = Config() - - -class ChatWithDb(BaseChat): - chat_scene: str = ChatScene.ChatWithDb.value - - """Number of results to return from the query""" - - def __init__(self, chat_session_id, db_name, user_input): - """ """ - super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input) - if not db_name: - raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!") - self.db_name = db_name - self.database = CFG.local_db - # 准备DB信息(拿到指定库的链接) - self.db_connect = self.database.get_session(self.db_name) - self.top_k: int = 5 - - def generate_input_values(self): - input_values = { - "input": self.current_user_input, - "top_k": str(self.top_k), - "dialect": self.database.dialect, - "table_info": self.database.table_simple_info(self.db_connect) - } - return input_values - - def do_with_prompt_response(self, prompt_response): - return self.database.run(self.db_connect, prompt_response.sql) - - # def call(self) -> str: - # input_values = { - # 
"input": self.current_user_input, - # "top_k": str(self.top_k), - # "dialect": self.database.dialect, - # "table_info": self.database.table_simple_info(self.db_connect), - # # "stop": self.sep_style, - # } - # - # ### Chat sequence advance - # self.current_message.chat_order = len(self.history_message) + 1 - # self.current_message.add_user_message(self.current_user_input) - # self.current_message.start_date = datetime.datetime.now() - # # TODO - # self.current_message.tokens = 0 - # - # current_prompt = self.prompt_template.format(**input_values) - # - # ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 - # if self.history_message: - # ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 - # logger.info( - # f"There are already {len(self.history_message)} rounds of conversations!" - # ) - # - # self.current_message.add_system_message(current_prompt) - # - # payload = { - # "model": self.llm_model, - # "prompt": self.generate_llm_text(), - # "temperature": float(self.temperature), - # "max_new_tokens": int(self.max_new_tokens), - # "stop": self.prompt_template.sep, - # } - # logger.info(f"Requert: \n{payload}") - # ai_response_text = "" - # try: - # ### 走非流式的模型服务接口 - # - # response = requests.post( - # urljoin(CFG.MODEL_SERVER, "generate"), - # headers=headers, - # json=payload, - # timeout=120, - # ) - # ai_response_text = ( - # self.prompt_template.output_parser.parse_model_server_out(response) - # ) - # self.current_message.add_ai_message(ai_response_text) - # prompt_define_response = ( - # self.prompt_template.output_parser.parse_prompt_response( - # ai_response_text - # ) - # ) - # - # result = self.database.run(self.db_connect, prompt_define_response.sql) - # - # if hasattr(prompt_define_response, "thoughts"): - # if prompt_define_response.thoughts.get("speak"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("speak"), result - # ) - # ) - # elif prompt_define_response.thoughts.get("reasoning"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("reasoning"), result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts, result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response, result - # ) - # ) - # - # except Exception as e: - # print(traceback.format_exc()) - # logger.error("model response parase faild!" 
+ str(e)) - # self.current_message.add_view_message( - # f"""ERROR!{str(e)}\n {ai_response_text} """ - # ) - # ### 对话记录存储 - # self.memory.append(self.current_message) - - def chat_show(self): - ret = [] - # 单论对话只能有一次User 记录 和一次 AI 记录 - # TODO 推理过程前端展示。。。 - for message in self.current_message.messages: - if isinstance(message, HumanMessage): - ret[-1][-2] = message.content - # 是否展示推理过程 - if isinstance(message, ViewMessage): - ret[-1][-1] = message.content - - return ret - - @property - def chat_type(self) -> str: - return ChatScene.ChatExecution.value - - -if __name__ == "__main__": - # chat: ChatWithDb = ChatWithDb("chat123", "gpt-user", "查询用户信息") - # - # chat.call() - # - # resp = chat.chat_show() - # - # print(vars(resp)) - - # memory = FileHistoryMemory("test123") - # once1 = OnceConversation() - # once1.add_user_message("问题测试") - # once1.add_system_message("prompt1") - # once1.add_system_message("prompt2") - # once1.chat_order = 1 - # once1.set_start_time(datetime.datetime.now()) - # memory.append(once1) - # - # once = OnceConversation() - # once.add_user_message("问题测试2") - # once.add_system_message("prompt3") - # once.add_system_message("prompt4") - # once.chat_order = 2 - # once.set_start_time(datetime.datetime.now()) - # memory.append(once) - - db: Database = CFG.local_db - db_connect = db.get_session("gpt-user") - data = db.run(db_connect, "select * from users") - print(generate_htm_table(data)) - - # - # print(db.run(db_connect, "select * from users")) - # - # # - # # def print_numbers(): - # # db_connect1 = db.get_session("dbgpt-test") - # # cursor1 = db_connect1.execute(text("select * from test_name")) - # # if cursor1.returns_rows: - # # result1 = cursor1.fetchall() - # # print( result1) - # # - # # - # # # 创建线程 - # # t = threading.Thread(target=print_numbers) - # # # 启动线程 - # # t.start() - # - # print(db.run(db_connect, "select * from tran_order")) - # - # print(db.run(db_connect, "select count(*) as aa from tran_order")) - # - # print(db.table_simple_info(db_connect)) - # my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9] - # index = 3 - # last_three_elements = my_list[-index:] - # print(last_three_elements) diff --git a/pilot/scene/chat_db/professional_qa/__init__.py b/pilot/scene/chat_db/professional_qa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py new file mode 100644 index 000000000..fbf5a8bb4 --- /dev/null +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -0,0 +1,56 @@ +from pilot.scene.base_message import ( + HumanMessage, + ViewMessage, +) +from pilot.scene.base_chat import BaseChat +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config +from pilot.common.markdown_text import ( + generate_htm_table, +) +from pilot.scene.chat_db.professional_qa.prompt import prompt + +CFG = Config() + + +class ChatWithDbQA(BaseChat): + chat_scene: str = ChatScene.ChatWithDbQA.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatWithDbQA, + chat_session_id=chat_session_id, + current_user_input=user_input) + self.db_name = db_name + if db_name: + self.database = CFG.local_db + # 准备DB信息(拿到指定库的链接) + self.db_connect = self.database.get_session(self.db_name) + self.top_k: int = 5 + + def 
generate_input_values(self):
+
+        table_info = ""
+        dialect = "mysql"
+        if self.db_name:
+            table_info = self.database.table_simple_info(self.db_connect)
+            dialect = self.database.dialect
+
+        input_values = {
+            "input": self.current_user_input,
+            "top_k": str(self.top_k),
+            "dialect": dialect,
+            "table_info": table_info,
+        }
+        return input_values
+
+    def do_with_prompt_response(self, prompt_response):
+        if getattr(self, "auto_execute", False):  # never set in QA mode; guarded so the lookup cannot raise AttributeError
+            return self.database.run(self.db_connect, prompt_response.sql)
+        else:
+            return prompt_response
diff --git a/pilot/scene/chat_db/professional_qa/out_parser.py b/pilot/scene/chat_db/professional_qa/out_parser.py
new file mode 100644
index 000000000..0f7ccd791
--- /dev/null
+++ b/pilot/scene/chat_db/professional_qa/out_parser.py
@@ -0,0 +1,22 @@
+import json
+import re
+from abc import ABC, abstractmethod
+from typing import Dict, NamedTuple
+import pandas as pd
+from pilot.utils import build_logger
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.model_config import LOGDIR
+
+
+logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+
+class NormalChatOutputParser(BaseOutputParser):
+
+    def parse_prompt_response(self, model_out_text) -> T:
+        return model_out_text
+
+    def parse_view_response(self, ai_text) -> str:
+        return super().parse_view_response(ai_text)
+
+    def get_format_instructions(self) -> str:
+        pass
diff --git a/pilot/scene/chat_db/professional_qa/prompt.py b/pilot/scene/chat_db/professional_qa/prompt.py
new file mode 100644
index 000000000..00fc87c03
--- /dev/null
+++ b/pilot/scene/chat_db/professional_qa/prompt.py
@@ -0,0 +1,48 @@
+import json
+import importlib
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_db.professional_qa.out_parser import NormalChatOutputParser
+from pilot.common.schema import SeparatorStyle
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant who is very familiar with database-related knowledge. """
+
+PROMPT_SUFFIX = """Only use the following tables to generate SQL, if any table info is given:
+{table_info}
+
+Question: {input}
+
+"""
+
+_DEFAULT_TEMPLATE = """
+You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
+Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
+You can order the results by a relevant column to return the most interesting examples in the database.
+Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
+Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
+
+"""
+
+
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = True
+
+prompt = PromptTemplate(
+    template_scene=ChatScene.ChatWithDbQA.value,
+    input_variables=["input", "table_info", "dialect", "top_k"],
+    response_format=None,
+    template_define=PROMPT_SCENE_DEFINE,
+    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=NormalChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+)
+
+CFG.prompt_templates.update({prompt.template_scene: prompt})
+
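A note on the registration pattern the new scene modules above rely on: each prompt.py builds a PromptTemplate and registers it into the shared Config as an import side effect, which is what lets BaseChat later resolve a template by scene value. A minimal sketch of that mechanism, with Config reduced to the singleton-dict shape used here (the scene key string below is illustrative, not the exact enum value):

    # Reduced model of pilot.configs.config.Config: a process-wide singleton
    # carrying a prompt_templates dict; only the shape matters for this sketch.
    class Config:
        _instance = None

        def __new__(cls):
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance.prompt_templates = {}
            return cls._instance

    CFG = Config()

    # Every scene's prompt.py ends with an update call like the one above, so
    # importing the module is enough to make its template discoverable:
    CFG.prompt_templates.update({"chat_with_db_qa": object()})  # stand-in template

    # BaseChat.__init__ then looks the template up by the scene's value:
    template = Config().prompt_templates["chat_with_db_qa"]
    assert template is CFG.prompt_templates["chat_with_db_qa"]
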
diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py
index 210b2ad77..27e79e3a5 100644
--- a/pilot/scene/chat_execution/chat.py
+++ b/pilot/scene/chat_execution/chat.py
@@ -10,8 +10,7 @@ from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
 from pilot.commands.command import execute_command
 from pilot.prompts.generator import PluginPromptGenerator
-
-from pilot.scene.chat_execution.prompt import chat_plugin_prompt
+from pilot.scene.chat_execution.prompt import prompt
 
 CFG = Config()
 
@@ -20,8 +19,12 @@ class ChatWithPlugin(BaseChat):
     plugins_prompt_generator:PluginPromptGenerator
     select_plugin: str = None
 
-    def __init__(self, chat_session_id, user_input, plugin_selector:str=None):
-        super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input)
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None):
+        super().__init__(temperature=temperature,
+                         max_new_tokens=max_new_tokens,
+                         chat_mode=ChatScene.ChatExecution,
+                         chat_session_id=chat_session_id,
+                         current_user_input=user_input)
         self.plugins_prompt_generator = PluginPromptGenerator()
         self.plugins_prompt_generator.command_registry = CFG.command_registry
         # Load the commands made available by the plugins
diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py
index f3f9e683e..ff5b6a0d7 100644
--- a/pilot/scene/chat_execution/out_parser.py
+++ b/pilot/scene/chat_execution/out_parser.py
@@ -20,8 +20,8 @@ class PluginChatOutputParser(BaseOutputParser):
 
     def parse_prompt_response(self, model_out_text) -> T:
         response = json.loads(super().parse_prompt_response(model_out_text))
-        sql, thoughts = response["command"], response["thoughts"]
-        return PluginAction(sql, thoughts)
+        command, thoughts = response["command"], response["thoughts"]
+        return PluginAction(command, thoughts)
 
     def parse_view_response(self, ai_text) -> str:
         return super().parse_view_response(ai_text)
diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py
index 44f564afe..6875689cf 100644
--- a/pilot/scene/chat_execution/prompt.py
+++ b/pilot/scene/chat_execution/prompt.py
@@ -1,4 +1,5 @@
 import json
+import importlib
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
@@ -50,7 +51,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
 ### Whether the model service is streaming output
 PROMPT_NEED_NEED_STREAM_OUT = False
 
-chat_plugin_prompt = PromptTemplate(
+prompt = PromptTemplate(
     template_scene=ChatScene.ChatExecution.value,
     input_variables=["input", "constraints", "commands_infos", "response"],
     response_format=json.dumps(RESPONSE_FORMAT, indent=4),
@@ -62,4 +63,4 @@ chat_plugin_prompt = PromptTemplate(
     ),
 )
 
-CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})
+CFG.prompt_templates.update({prompt.template_scene: prompt})
\ No newline at end of 
file diff --git a/pilot/scene/chat_execution/prompt_with_command.py b/pilot/scene/chat_execution/prompt_with_command.py deleted file mode 100644 index e3469d7c2..000000000 --- a/pilot/scene/chat_execution/prompt_with_command.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -from pilot.prompts.prompt_new import PromptTemplate -from pilot.configs.config import Config -from pilot.scene.base import ChatScene -from pilot.common.schema import SeparatorStyle - -from pilot.scene.chat_execution.out_parser import PluginChatOutputParser - - -CFG = Config() - -PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications.""" - -PROMPT_SUFFIX = """ -Goals: - {input} - -""" - -_DEFAULT_TEMPLATE = """ -Constraints: - Exclusively use the commands listed in double quotes e.g. "command name" - Reflect on past decisions and strategies to refine your approach. - Constructively self-criticize your big-picture behavior constantly. - {constraints} - -Commands: - {commands_infos} -""" - - -PROMPT_RESPONSE = """You must respond in JSON format as following format: -{response} - -Ensure the response is correct json and can be parsed by Python json.loads -""" - -RESPONSE_FORMAT = { - "thoughts": { - "text": "thought", - "reasoning": "reasoning", - "plan": "- short bulleted\n- list that conveys\n- long-term plan", - "criticism": "constructive self-criticism", - "speak": "thoughts summary to say to user", - }, - "command": {"name": "command name", "args": {"arg name": "value"}}, -} - -PROMPT_SEP = SeparatorStyle.SINGLE.value -### Whether the model service is streaming output -PROMPT_NEED_NEED_STREAM_OUT = False - -chat_plugin_prompt = PromptTemplate( - template_scene=ChatScene.ChatExecution.value, - input_variables=["input", "table_info", "dialect", "top_k", "response"], - response_format=json.dumps(RESPONSE_FORMAT, indent=4), - template_define=PROMPT_SCENE_DEFINE, - template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE, - stream_out=PROMPT_NEED_NEED_STREAM_OUT, - output_parser=PluginChatOutputParser( - sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT - ), -) - -CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt}) diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py index 97c547390..7a346cbda 100644 --- a/pilot/scene/chat_factory.py +++ b/pilot/scene/chat_factory.py @@ -1,8 +1,14 @@ from pilot.scene.base_chat import BaseChat from pilot.singleton import Singleton -from pilot.scene.chat_db.chat import ChatWithDb +import inspect +import importlib from pilot.scene.chat_execution.chat import ChatWithPlugin - +from pilot.scene.chat_normal.chat import ChatNormal +from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA +from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute +from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge +from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge +from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge class ChatFactory(metaclass=Singleton): @staticmethod @@ -13,5 +19,5 @@ class ChatFactory(metaclass=Singleton): if cls.chat_scene == chat_mode: implementation = cls(**kwargs) if implementation == None: - raise Exception("Invalid implementation name:" + chat_mode) + raise Exception(f"Invalid implementation name:{chat_mode}") return implementation diff --git 
a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py new file mode 100644 index 000000000..7b9a11f85 --- /dev/null +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -0,0 +1,69 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) + +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, + VECTOR_SEARCH_TOP_K, +) + +from pilot.scene.chat_normal.prompt import prompt +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding + +CFG = Config() + + +class ChatNewKnowledge (BaseChat): + chat_scene: str = ChatScene.ChatNewKnowledge.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatNewKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input) + self.knowledge_name = knowledge_name + vector_store_config = { + "vector_store_name": knowledge_name, + "text_field": "content", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + self.knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) + + + def generate_input_values(self): + docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = docs[:2000] + input_values = { + "context": docs, + "question": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatNewKnowledge.value diff --git a/pilot/scene/chat_knowledge/custom/out_parser.py b/pilot/scene/chat_knowledge/custom/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_knowledge/custom/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py new file mode 100644 index 000000000..175deaddb --- /dev/null +++ b/pilot/scene/chat_knowledge/custom/prompt.py @@ -0,0 +1,43 @@ +import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser + + +CFG = Config() + +_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的内容中获取答案, 请说: 
"知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 + 已知内容: + {context} + 问题: + {question} +""" + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatNewKnowledge.value, + input_variables=["context", "question"], + response_format=None, + template_define=None, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + + diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py new file mode 100644 index 000000000..978570d91 --- /dev/null +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -0,0 +1,66 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) + +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, + VECTOR_SEARCH_TOP_K, +) + +from pilot.scene.chat_normal.prompt import prompt +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding + +CFG = Config() + + +class ChatDefaultKnowledge (BaseChat): + chat_scene: str = ChatScene.ChatKnowledge.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input) + vector_store_config = { + "vector_store_name": "default", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + self.knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) + + def generate_input_values(self): + docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = docs[:2000] + input_values = { + "context": docs, + "question": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatKnowledge.value diff --git a/pilot/scene/chat_knowledge/default/out_parser.py b/pilot/scene/chat_knowledge/default/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_knowledge/default/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py new file mode 100644 
index 000000000..d2ba473ab --- /dev/null +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -0,0 +1,43 @@ +import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser + + +CFG = Config() + +_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 + 已知内容: + {context} + 问题: + {question} +""" + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatKnowledge.value, + input_variables=["context", "question"], + response_format=None, + template_define=None, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + + diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py new file mode 100644 index 000000000..0c54f6001 --- /dev/null +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -0,0 +1,71 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) + +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, + VECTOR_SEARCH_TOP_K, +) + +from pilot.scene.chat_normal.prompt import prompt +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding + +CFG = Config() + + +class ChatUrlKnowledge (BaseChat): + chat_scene: str = ChatScene.ChatUrlKnowledge.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatUrlKnowledge, + chat_session_id=chat_session_id, + current_user_input=user_input) + self.url = url + vector_store_config = { + "vector_store_name": url, + "text_field": "content", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + self.knowledge_embedding_client = KnowledgeEmbedding( + file_path=url, + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) + + # url soruce in vector + self.knowledge_embedding_client.knowledge_embedding() + + def generate_input_values(self): + docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K) + docs = docs[:2000] + input_values = { + "context": docs, + "question": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatUrlKnowledge.value diff --git a/pilot/scene/chat_knowledge/url/out_parser.py b/pilot/scene/chat_knowledge/url/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_knowledge/url/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import 
build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py new file mode 100644 index 000000000..a5c1fe226 --- /dev/null +++ b/pilot/scene/chat_knowledge/url/prompt.py @@ -0,0 +1,43 @@ +import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser + + +CFG = Config() + +_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 + 已知内容: + {context} + 问题: + {question} +""" + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatUrlKnowledge.value, + input_variables=["context", "question"], + response_format=None, + template_define=None, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + + diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py new file mode 100644 index 000000000..edd6ac53c --- /dev/null +++ b/pilot/scene/chat_normal/chat.py @@ -0,0 +1,43 @@ + +from pilot.scene.base_chat import BaseChat, logger, headers +from pilot.scene.base import ChatScene +from pilot.common.sql_database import Database +from pilot.configs.config import Config + +from pilot.common.markdown_text import ( + generate_markdown_table, + generate_htm_table, + datas_to_table_html, +) +from pilot.scene.chat_normal.prompt import prompt + +CFG = Config() + + +class ChatNormal(BaseChat): + chat_scene: str = ChatScene.ChatNormal.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.ChatNormal, + chat_session_id=chat_session_id, + current_user_input=user_input) + + def generate_input_values(self): + input_values = { + "input": self.current_user_input + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.ChatNormal.value diff --git a/pilot/scene/chat_normal/out_parser.py b/pilot/scene/chat_normal/out_parser.py new file mode 100644 index 000000000..0f7ccd791 --- /dev/null +++ b/pilot/scene/chat_normal/out_parser.py @@ -0,0 +1,22 @@ +import json +import re +from abc import ABC, abstractmethod +from typing import Dict, NamedTuple +import pandas as pd +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + +class NormalChatOutputParser(BaseOutputParser): + + def 
parse_prompt_response(self, model_out_text) -> T: + return model_out_text + + def parse_view_response(self, ai_text) -> str: + return super().parse_view_response(ai_text) + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py index fd21f2102..7a11cf3ff 100644 --- a/pilot/scene/chat_normal/prompt.py +++ b/pilot/scene/chat_normal/prompt.py @@ -1,31 +1,33 @@ import builtins +import importlib + +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle + +from pilot.scene.chat_normal.out_parser import NormalChatOutputParser -def stream_write_and_read(lst): - # 对lst使用yield from进行可迭代对象的扁平化 - yield from lst - while True: - val = yield - lst.append(val) +CFG = Config() + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatNormal.value, + input_variables=["input"], + response_format=None, + template_define=None, + template=None, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) -if __name__ == "__main__": - # 创建一个空列表 - my_list = [] +CFG.prompt_templates.update({prompt.template_scene: prompt}) - # 使用生成器写入数据 - stream_writer = stream_write_and_read(my_list) - next(stream_writer) - stream_writer.send(10) - print(1) - stream_writer.send(20) - print(2) - stream_writer.send(30) - print(3) - # 使用生成器读取数据 - stream_reader = stream_write_and_read(my_list) - next(stream_reader) - print(stream_reader.send(None)) - print(stream_reader.send(None)) - print(stream_reader.send(None)) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 2e8f61016..515701255 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - +import traceback import argparse import datetime import json @@ -9,7 +9,6 @@ import shutil import sys import time import uuid -from urllib.parse import urljoin import gradio as gr import requests @@ -216,19 +215,26 @@ def post_process_code(code): return code -def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene: +def get_chat_mode(selected, param=None) -> ChatScene: if chat_mode_title['chat_use_plugin'] == selected: return ChatScene.ChatExecution elif chat_mode_title['knowledge_qa'] == selected: + mode= param if mode == conversation_types["default_knownledge"]: return ChatScene.ChatKnowledge elif mode == conversation_types["custome"]: return ChatScene.ChatNewKnowledge + elif mode == conversation_types["url"]: + return ChatScene.ChatUrlKnowledge + else: + return ChatScene.ChatNormal else: - if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector: - return ChatScene.ChatWithDb + sql_mode= param + if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: + return ChatScene.ChatWithDbExecute + else: + return ChatScene.ChatWithDbQA - return ChatScene.ChatNormal def chatbot_callback(state, message): print(f"chatbot_callback:{message}") @@ -237,244 +243,99 @@ def chatbot_callback(state, message): def http_bot( - state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request + state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name ): - logger.info(f"User message 
send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}") - start_tstamp = time.time() - scene:ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector) - print(f"now chat scene:{scene.value}") - model_name = CFG.LLM_MODEL - if ChatScene.ChatWithDb == scene: - logger.info("chat with db mode use new architecture design!") + logger.info(f"User message send!{state.conv_id},{selected}") + if chat_mode_title['knowledge_qa'] == selected: + scene: ChatScene = get_chat_mode(selected, mode) + elif chat_mode_title['chat_use_plugin'] == selected: + scene: ChatScene = get_chat_mode(selected) + else: + scene: ChatScene = get_chat_mode(selected, sql_mode) + print(f"chat scene:{scene.value}") + + if ChatScene.ChatWithDbExecute == scene: chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "db_name": db_selector, + "user_input": state.last_user_input + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatWithDbQA == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, "chat_session_id": state.conv_id, "db_name": db_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - chat.call() - - state.messages[-1][-1] = chat.current_ai_response() - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - elif ChatScene.ChatExecution == scene: - logger.info("plugin mode use new architecture design!") chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, "chat_session_id": state.conv_id, "plugin_selector": plugin_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - strem_generate = chat.stream_call() + elif ChatScene.ChatNormal == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatKnowledge == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatNewKnowledge == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + "knowledge_name": knowledge_name + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + elif ChatScene.ChatUrlKnowledge == scene: + chat_param = { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "chat_session_id": state.conv_id, + "user_input": state.last_user_input, + "url": url_input + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) - for msg in strem_generate: - state.messages[-1][-1] = msg + if not chat.prompt_template.stream_out: + logger.info("not stream out, wait model response!") + state.messages[-1][-1] = chat.nostream_call() + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + else: + logger.info("stream out start!") + try: + stream_gen = chat.stream_call() + for msg in stream_gen: + state.messages[-1][-1] = msg + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + except Exception as e: + 
print(traceback.format_exc()) + state.messages[-1][-1] = "Error:" + str(e) yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - # def generate_numbers(): - # for i in range(10): - # time.sleep(0.5) - # yield f"Message:{i}" - # - # def showMessage(message): - # return message - # - # for n in generate_numbers(): - # state.messages[-1][-1] = n + "▌" - # yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - else: - - dbname = db_selector - # TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化 - if state.skip_next: - # This generate call is skipped due to invalid inputs - yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 - return - - if len(state.messages) == state.offset + 2: - query = state.messages[-2][1] - - template_name = "conv_one_shot" - new_state = conv_templates[template_name].copy() - # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? - # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 - if db_selector: - new_state.append_message( - new_state.roles[0], gen_sqlgen_conversation(dbname) + query - ) - new_state.append_message(new_state.roles[1], None) - else: - new_state.append_message(new_state.roles[0], query) - new_state.append_message(new_state.roles[1], None) - - new_state.conv_id = uuid.uuid4().hex - state = new_state - else: - ### 后续对话 - query = state.messages[-2][1] - # 第一轮对话需要加入提示Prompt - if mode == conversation_types["custome"]: - template_name = "conv_one_shot" - new_state = conv_templates[template_name].copy() - # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? - # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 - if db_selector: - new_state.append_message( - new_state.roles[0], gen_sqlgen_conversation(dbname) + query - ) - new_state.append_message(new_state.roles[1], None) - else: - new_state.append_message(new_state.roles[0], query) - new_state.append_message(new_state.roles[1], None) - state = new_state - - prompt = state.get_prompt() - skip_echo_len = len(prompt.replace("", " ")) + 1 - if mode == conversation_types["default_knownledge"] and not db_selector: - vector_store_config = { - "vector_store_name": "default", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, - } - knowledge_embedding_client = KnowledgeEmbedding( - file_path="", - model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, - vector_store_config=vector_store_config, - ) - query = state.messages[-2][1] - docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) - prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state) - state.messages[-2][1] = query - skip_echo_len = len(prompt.replace("", " ")) + 1 - - if mode == conversation_types["custome"] and not db_selector: - print("vector store name: ", vector_store_name["vs_name"]) - vector_store_config = { - "vector_store_name": vector_store_name["vs_name"], - "text_field": "content", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, - } - knowledge_embedding_client = KnowledgeEmbedding( - file_path="", - model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, - vector_store_config=vector_store_config, - ) - query = state.messages[-2][1] - docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) - prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state) - - state.messages[-2][1] = query - skip_echo_len = len(prompt.replace("", " ")) + 1 - - if mode == conversation_types["url"] and url_input: - print("url: ", url_input) - vector_store_config = { - "vector_store_name": url_input, - "text_field": "content", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, - } - 
knowledge_embedding_client = KnowledgeEmbedding(
-                file_path=url_input,
-                model_name=LLM_MODEL_CONFIG["text2vec"],
-                local_persist=False,
-                vector_store_config=vector_store_config,
-            )
-
-            query = state.messages[-2][1]
-            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
-
-            state.messages[-2][1] = query
-            skip_echo_len = len(prompt.replace("", " ")) + 1
-
-        # Make requests
-        payload = {
-            "model": model_name,
-            "prompt": prompt,
-            "temperature": float(temperature),
-            "max_new_tokens": int(max_new_tokens),
-            "stop": state.sep
-            if state.sep_style == SeparatorStyle.SINGLE
-            else state.sep2,
-        }
-        logger.info(f"Requert: \n{payload}")
-
-        # 流式输出
-        state.messages[-1][-1] = "▌"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-
-        try:
-            # Stream output
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=20,
-            )
-            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-                if chunk:
-                    data = json.loads(chunk.decode())
-
-                    """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
-                    """
-                    if data["error_code"] == 0:
-                        if "vicuna" in CFG.LLM_MODEL:
-                            output = data["text"][skip_echo_len:].strip()
-                        else:
-                            output = data["text"].strip()
-
-                        output = post_process_code(output)
-                        state.messages[-1][-1] = output + "▌"
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                        ) * 5
-                    else:
-                        output = (
-                            data["text"] + f" (error_code: {data['error_code']})"
-                        )
-                        state.messages[-1][-1] = output
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                            disable_btn,
-                            disable_btn,
-                            enable_btn,
-                            enable_btn,
-                        )
-                        return
-
-        except requests.exceptions.RequestException as e:
-            state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-            yield (state, state.to_gradio_chatbot()) + (
-                disable_btn,
-                disable_btn,
-                disable_btn,
-                enable_btn,
-                enable_btn,
-            )
-            return
-
-        state.messages[-1][-1] = state.messages[-1][-1][:-1]
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
-        # 记录运行日志
-        finish_tstamp = time.time()
-        logger.info(f"{output}")
-
-        with open(get_conv_log_filename(), "a") as fout:
-            data = {
-                "tstamp": round(finish_tstamp, 4),
-                "type": "chat",
-                "model": model_name,
-                "start": round(start_tstamp, 4),
-                "finish": round(start_tstamp, 4),
-                "state": state.dict(),
-                "ip": request.client.host,
-            }
-            fout.write(json.dumps(data) + "\n")
-
+    if state.messages[-1][-1].endswith("▌"):
+        state.messages[-1][-1] = state.messages[-1][-1][:-1]
+        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
 block_css = (
     code_highlight_css
@@ -556,6 +417,7 @@ def build_single_model_ui():
             value=dbs[0] if len(models) > 0 else "",
             interactive=True,
             show_label=True,
+            name="db_selector"
         ).style(container=False)
 
         sql_mode = gr.Radio(
             [
                 get_lang_text("sql_generate_mode_direct"),
                 get_lang_text("sql_generate_mode_none"),
             ],
             show_label=False,
             value=get_lang_text("sql_generate_mode_none"),
+            name="sql_mode"
         )
         sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
         sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
@@ -581,7 +444,8 @@
             value="",
             interactive=True,
             show_label=True,
-            type="value"
+            type="value",
+            name="plugin_selector"
         ).style(container=False)
 
         def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
@@ -602,13 +466,14 @@
             ],
             show_label=False,
             value=llm_native_dialogue,
name="mode" ) vs_setting = gr.Accordion( - get_lang_text("configure_knowledge_base"), open=False + get_lang_text("configure_knowledge_base"), open=False, visible=False ) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) - url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True) + url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False, name="url_input") def show_url_input(evt:gr.SelectData): if evt.value == url_knowledge_dialogue: return gr.update(visible=True) @@ -619,7 +484,7 @@ def build_single_model_ui(): with vs_setting: vs_name = gr.Textbox( - label=get_lang_text("new_klg_name"), lines=1, interactive=True + label=get_lang_text("new_klg_name"), lines=1, interactive=True, name = "vs_name" ) vs_add = gr.Button(get_lang_text("add_as_new_klg")) with gr.Column() as doc2vec: @@ -664,10 +529,14 @@ def build_single_model_ui(): clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False) gr.Markdown(learn_more_markdown) + + params = [plugin_selector, mode, sql_mode, db_selector, url_input, vs_name] + + btn_list = [regenerate_btn, clear_btn] regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], + [state, selected, temperature, max_output_tokens] + params, [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) @@ -676,7 +545,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], + [state, selected, temperature, max_output_tokens]+ params, [state, chatbot] + btn_list, ) @@ -684,7 +553,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens], + [state, selected, temperature, max_output_tokens]+ params, [state, chatbot] + btn_list, ) vs_add.click( diff --git a/requirements.txt b/requirements.txt index f476a4b23..b2f582eed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ fonttools==4.38.0 frozenlist==1.3.3 huggingface-hub==0.13.4 importlib-resources==5.12.0 + kiwisolver==1.4.4 matplotlib==3.7.0 multidict==6.0.4 From 3a46dfd3c2d152419eb85a6a5095efa09302789a Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Wed, 31 May 2023 16:26:47 +0800 Subject: [PATCH 6/7] Implemented a new multi-scenario dialogue architecture --- pilot/conversation.py | 4 +-- pilot/memory/chat_history/mem_history.py | 33 ++++++++++++++++++++++ pilot/scene/base_chat.py | 6 +++- pilot/scene/chat_db/auto_execute/prompt.py | 6 ++-- 4 files changed, 42 insertions(+), 7 deletions(-) create mode 100644 pilot/memory/chat_history/mem_history.py diff --git a/pilot/conversation.py b/pilot/conversation.py index 8a758dd51..3fe648529 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -108,8 +108,8 @@ class Conversation: conv_default = Conversation( system = None, roles=("human", "ai"), - messages= (), - offset=2, + messages=[], + offset=0, sep_style=SeparatorStyle.SINGLE, sep="###", ) diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/mem_history.py new file mode 100644 index 000000000..8ff0d08eb --- /dev/null +++ 
diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/mem_history.py
new file mode 100644
index 000000000..8ff0d08eb
--- /dev/null
+++ b/pilot/memory/chat_history/mem_history.py
@@ -0,0 +1,33 @@
+from typing import List
+import json
+import os
+import datetime
+from pilot.memory.chat_history.base import BaseChatHistoryMemory
+from pathlib import Path
+
+from pilot.configs.config import Config
+from pilot.scene.message import (
+    OnceConversation,
+    conversation_from_dict,
+    conversations_to_dict,
+)
+
+
+CFG = Config()
+
+
+class MemHistoryMemory(BaseChatHistoryMemory):
+    histories_map = {}
+
+    def __init__(self, chat_session_id: str):
+        self.chat_session_id = chat_session_id
+        self.histories_map.update({chat_session_id: []})
+
+    def messages(self) -> List[OnceConversation]:
+        return self.histories_map.get(self.chat_session_id)
+
+    def append(self, once_message: OnceConversation) -> None:
+        self.histories_map.get(self.chat_session_id).append(once_message)
+
+    def clear(self) -> None:
+        self.histories_map.pop(self.chat_session_id)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 798e071e3..dce25bec4 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -23,6 +23,7 @@ from pilot.scene.message import OnceConversation
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pilot.memory.chat_history.file_history import FileHistoryMemory
+from pilot.memory.chat_history.mem_history import MemHistoryMemory
 from pilot.configs.model_config import LOGDIR, DATASETS_DIR
 
 from pilot.utils import (
@@ -61,7 +62,10 @@ class BaseChat(ABC):
         self.chat_mode = chat_mode
         self.current_user_input: str = current_user_input
         self.llm_model = CFG.LLM_MODEL
-        ### TODO
+        ### The history storage backend should be configurable; in-memory variant:
+        # self.memory = MemHistoryMemory(chat_session_id)
+
+        ## Use file-backed history for now
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
         self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
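The commented-out MemHistoryMemory line above hints at pluggable history storage. One way that could be wired up, sketched under the assumption of a hypothetical store_type setting (the helper name and parameter are not part of this series; the imports mirror the modified pilot/scene/base_chat.py):

    from pilot.memory.chat_history.file_history import FileHistoryMemory
    from pilot.memory.chat_history.mem_history import MemHistoryMemory

    def build_history_memory(chat_session_id: str, store_type: str = "file"):
        """Pick a chat-history backend; the file store stays the durable default."""
        if store_type == "memory":
            # per-process only: contents are lost on restart
            return MemHistoryMemory(chat_session_id)
        return FileHistoryMemory(chat_session_id)
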
""" -PROMPT_SUFFIX = """Only use the following tables: +PROMPT_SUFFIX = """Only use the following tables generate sql: {table_info} Question: {input} From ced9b581fc5f324841db90c824beb451d5022e1c Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Wed, 31 May 2023 18:09:20 +0800 Subject: [PATCH 7/7] chat with plugin bug fix --- pilot/scene/base_chat.py | 37 +++++++----------- .../chat_db/professional_qa/out_parser.py | 3 -- pilot/scene/chat_execution/chat.py | 2 +- pilot/scene/chat_execution/out_parser.py | 7 +++- .../scene/chat_knowledge/custom/out_parser.py | 3 -- .../chat_knowledge/default/out_parser.py | 3 -- pilot/scene/chat_knowledge/url/out_parser.py | 3 -- pilot/server/webserver.py | 17 ++++---- plugins/DB-GPT-Plugin-ByteBase.zip | Bin 0 -> 30614 bytes 9 files changed, 27 insertions(+), 48 deletions(-) create mode 100644 plugins/DB-GPT-Plugin-ByteBase.zip diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index dce25bec4..650235d63 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -179,31 +179,22 @@ class BaseChat(ABC): result = self.do_with_prompt_response(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): - if hasattr(prompt_define_response.thoughts, "speak"): - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response.thoughts.get("speak"), result - ) - ) - elif hasattr(prompt_define_response.thoughts, "reasoning"): - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response.thoughts.get("reasoning"), result - ) - ) + if isinstance(prompt_define_response.thoughts, dict): + if "speak" in prompt_define_response.thoughts: + speak_to_user = prompt_define_response.thoughts.get("speak") + else: + speak_to_user = str(prompt_define_response.thoughts) else: - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response.thoughts, result - ) - ) + if hasattr(prompt_define_response.thoughts, "speak"): + speak_to_user = prompt_define_response.thoughts.get("speak") + elif hasattr(prompt_define_response.thoughts, "reasoning"): + speak_to_user = prompt_define_response.thoughts.get("reasoning") + else: + speak_to_user = prompt_define_response.thoughts else: - self.current_message.add_view_message( - self.prompt_template.output_parser.parse_view_response( - prompt_define_response, result - ) - ) - + speak_to_user = prompt_define_response + view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result) + self.current_message.add_view_message(view_message) except Exception as e: print(traceback.format_exc()) logger.error("model response parase faild!" 
 + str(e))
diff --git a/pilot/scene/chat_db/professional_qa/out_parser.py b/pilot/scene/chat_db/professional_qa/out_parser.py
index 0f7ccd791..0b8277d63 100644
--- a/pilot/scene/chat_db/professional_qa/out_parser.py
+++ b/pilot/scene/chat_db/professional_qa/out_parser.py
@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 
-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py
index 27e79e3a5..464df9ba0 100644
--- a/pilot/scene/chat_execution/chat.py
+++ b/pilot/scene/chat_execution/chat.py
@@ -55,7 +55,7 @@ class ChatWithPlugin(BaseChat):
 
     def do_with_prompt_response(self, prompt_response):
         ## plugin command run
-        return execute_command(str(prompt_response), self.plugins_prompt_generator)
+        return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args', {}), self.plugins_prompt_generator)
 
     def chat_show(self):
         super().chat_show()
diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py
index ff5b6a0d7..f9796ef3d 100644
--- a/pilot/scene/chat_execution/out_parser.py
+++ b/pilot/scene/chat_execution/out_parser.py
@@ -23,8 +23,11 @@ class PluginChatOutputParser(BaseOutputParser):
         command, thoughts = response["command"], response["thoughts"]
         return PluginAction(command, thoughts)
 
-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
+    def parse_view_response(self, speak, data) -> str:
+        ### Render the executed command's output as a view for the chat window
+        print(f"parse_view_response:{speak},{str(data)}")
+        view_text = f"##### {speak}" + "\n" + str(data)
+        return view_text
 
     def get_format_instructions(self) -> str:
         pass
diff --git a/pilot/scene/chat_knowledge/custom/out_parser.py b/pilot/scene/chat_knowledge/custom/out_parser.py
index 0f7ccd791..0b8277d63 100644
--- a/pilot/scene/chat_knowledge/custom/out_parser.py
+++ b/pilot/scene/chat_knowledge/custom/out_parser.py
@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 
-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
diff --git a/pilot/scene/chat_knowledge/default/out_parser.py b/pilot/scene/chat_knowledge/default/out_parser.py
index 0f7ccd791..0b8277d63 100644
--- a/pilot/scene/chat_knowledge/default/out_parser.py
+++ b/pilot/scene/chat_knowledge/default/out_parser.py
@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 
-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
diff --git a/pilot/scene/chat_knowledge/url/out_parser.py b/pilot/scene/chat_knowledge/url/out_parser.py
index 0f7ccd791..0b8277d63 100644
--- a/pilot/scene/chat_knowledge/url/out_parser.py
+++ b/pilot/scene/chat_knowledge/url/out_parser.py
@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 
-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
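Taken together, the hunks above settle the plugin round trip: the model must answer with a JSON command/thoughts object, ChatWithPlugin dispatches on the command dict, and parse_view_response renders the spoken summary plus tool output. A worked example of that flow; PluginAction is modeled here as a NamedTuple and the reply JSON is illustrative, not captured model output:

    import json
    from typing import Dict, NamedTuple

    class PluginAction(NamedTuple):  # stand-in for the class used by the parser
        command: Dict
        thoughts: Dict

    model_reply = json.dumps({
        "thoughts": {"speak": "Listing the available databases"},
        "command": {"name": "list_databases", "args": {}},
    })

    parsed = json.loads(model_reply)
    action = PluginAction(parsed["command"], parsed["thoughts"])

    # ChatWithPlugin.do_with_prompt_response dispatches on the command dict:
    name = str(action.command.get("name"))   # -> "list_databases"
    args = action.command.get("args", {})    # -> {}

    # parse_view_response then renders speak + tool output for the chat view:
    data = ["db1", "db2"]                    # pretend command result
    view_text = f"##### {action.thoughts.get('speak')}" + "\n" + str(data)
    print(view_text)
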
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 515701255..af91fa4e9 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -246,7 +246,7 @@ def http_bot(
     state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
 ):
-    logger.info(f"User message send!{state.conv_id},{selected}")
+    logger.info(f"User message sent!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
     if chat_mode_title['knowledge_qa'] == selected:
         scene: ChatScene = get_chat_mode(selected, mode)
     elif chat_mode_title['chat_use_plugin'] == selected:
@@ -417,7 +417,6 @@ def build_single_model_ui():
             value=dbs[0] if len(models) > 0 else "",
             interactive=True,
             show_label=True,
-            name="db_selector"
         ).style(container=False)
 
         sql_mode = gr.Radio(
@@ -426,8 +425,7 @@
             get_lang_text("sql_generate_mode_none"),
         ],
         show_label=False,
-        value=get_lang_text("sql_generate_mode_none"),
-        name="sql_mode"
+        value=get_lang_text("sql_generate_mode_none")
     )
     sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
     sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
@@ -444,12 +442,12 @@
             value="",
             interactive=True,
             show_label=True,
-            type="value",
-            name="plugin_selector"
+            type="value"
         ).style(container=False)
 
         def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
             print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+            print(f"user plugin:{plugins_select_info().get(evt.value)}")
             return plugins_select_info().get(evt.value)
 
         plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
@@ -466,14 +464,13 @@
         ],
         show_label=False,
         value=llm_native_dialogue,
-        name="mode"
     )
 
     vs_setting = gr.Accordion(
         get_lang_text("configure_knowledge_base"), open=False, visible=False
     )
     mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
-    url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False, name="url_input")
+    url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
     def show_url_input(evt:gr.SelectData):
         if evt.value == url_knowledge_dialogue:
             return gr.update(visible=True)
@@ -484,7 +481,7 @@
 
     with vs_setting:
         vs_name = gr.Textbox(
-            label=get_lang_text("new_klg_name"), lines=1, interactive=True, name = "vs_name"
+            label=get_lang_text("new_klg_name"), lines=1, interactive=True
         )
         vs_add = gr.Button(get_lang_text("add_as_new_klg"))
@@ -530,7 +527,7 @@
     gr.Markdown(learn_more_markdown)
 
-    params = [plugin_selector, mode, sql_mode, db_selector, url_input, vs_name]
+    params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]
 
     btn_list = [regenerate_btn, clear_btn]
diff --git a/plugins/DB-GPT-Plugin-ByteBase.zip b/plugins/DB-GPT-Plugin-ByteBase.zip
new file mode 100644
index 0000000000000000000000000000000000000000..60ce2816dcd42907c41c3be79f7b1cd713b7116a
GIT binary patch
literal 30614
[base85-encoded zip payload truncated]
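For reference, the scene routing that http_bot and get_chat_mode converge on across this series, condensed into one runnable function. Plain strings stand in for the gradio selections and the ChatScene members, which live in pilot.scene.base:

    def route(selected: str, param: str = None) -> str:
        """Mirror of get_chat_mode: tab selection plus one mode parameter."""
        if selected == "chat_use_plugin":
            return "chat_execution"
        if selected == "knowledge_qa":
            return {
                "default_knownledge": "chat_knowledge",
                "custome": "chat_new_knowledge",
                "url": "chat_url_knowledge",
            }.get(param, "chat_normal")
        # database tab: auto-execute SQL or professional QA
        if param == "auto_execute_ai_response":
            return "chat_with_db_execute"
        return "chat_with_db_qa"

    assert route("knowledge_qa", "url") == "chat_url_knowledge"
    assert route("chat_use_plugin") == "chat_execution"
    assert route("database", "dont_execute_ai_response") == "chat_with_db_qa"

The chosen scene value is then handed to ChatFactory.get_implementation together with the chat parameters, and the response is produced by nostream_call or stream_call depending on the template's stream_out flag.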