From 2fc62c16efc509d02c94a350b4043e346aa39abe Mon Sep 17 00:00:00 2001
From: yhjun1026 <460342015@qq.com>
Date: Wed, 24 May 2023 17:33:40 +0800
Subject: [PATCH 01/22] Multi-scenario dialogue architecture, phase one
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .env.template                             |   2 +-
 .gitignore                                |   2 +
 pilot/common/formatting.py                |  38 +++
 pilot/common/markdown_text.py             |  48 +++
 pilot/common/schema.py                    |   8 +
 pilot/common/sql_database.py              | 308 ++++++++++++++++++
 pilot/configs/config.py                   |  10 +-
 pilot/connections/mysql.py                |  14 +-
 pilot/conversation.py                     |   4 +-
 pilot/inference_parsers/__init__.py       |   0
 pilot/inference_parsers/base.py           |  71 +++++
 pilot/memory/__init__.py                  |   0
 pilot/memory/chat_history/__init__.py     |   0
 pilot/memory/chat_history/base.py         |  40 +++
 pilot/memory/chat_history/file_history.py |  45 +++
 pilot/out_parser/__init__.py              |   0
 pilot/out_parser/base.py                  | 101 ++++++
 pilot/prompts/base.py                     |  51 +++
 pilot/prompts/generator_new.py            |   0
 pilot/prompts/prompt_generator.py         |  54 ++++
 pilot/prompts/prompt_new.py               | 104 +++++++
 pilot/prompts/prompt_template.py          | 363 ++++++++++++++++++++++
 pilot/scene/__init__.py                   |   0
 pilot/scene/base.py                       |   8 +
 pilot/scene/base_chat.py                  | 102 ++++++
 pilot/scene/base_message.py               | 141 +++++++++
 pilot/scene/chat_db/__init__.py           |   0
 pilot/scene/chat_db/chat.py               | 223 +++++++++++++
 pilot/scene/chat_db/out_parser.py         |  45 +++
 pilot/scene/chat_db/prompt.py             |  53 ++++
 pilot/scene/chat_execution/__init__.py    |   0
 pilot/scene/chat_execution/chat.py        |  26 ++
 pilot/scene/chat_factory.py               |  20 ++
 pilot/scene/chat_knowledge/__init__.py    |   0
 pilot/scene/chat_normal/__init__.py       |   0
 pilot/scene/message.py                    |  80 +++++
 pilot/server/webserver.py                 | 222 ++++++------
 37 files changed, 2052 insertions(+), 131 deletions(-)
 create mode 100644 pilot/common/formatting.py
 create mode 100644 pilot/common/markdown_text.py
 create mode 100644 pilot/common/schema.py
 create mode 100644 pilot/common/sql_database.py
 create mode 100644 pilot/inference_parsers/__init__.py
 create mode 100644 pilot/inference_parsers/base.py
 create mode 100644 pilot/memory/__init__.py
 create mode 100644 pilot/memory/chat_history/__init__.py
 create mode 100644 pilot/memory/chat_history/base.py
 create mode 100644 pilot/memory/chat_history/file_history.py
 create mode 100644 pilot/out_parser/__init__.py
 create mode 100644 pilot/out_parser/base.py
 create mode 100644 pilot/prompts/base.py
 create mode 100644 pilot/prompts/generator_new.py
 create mode 100644 pilot/prompts/prompt_generator.py
 create mode 100644 pilot/prompts/prompt_new.py
 create mode 100644 pilot/prompts/prompt_template.py
 create mode 100644 pilot/scene/__init__.py
 create mode 100644 pilot/scene/base.py
 create mode 100644 pilot/scene/base_chat.py
 create mode 100644 pilot/scene/base_message.py
 create mode 100644 pilot/scene/chat_db/__init__.py
 create mode 100644 pilot/scene/chat_db/chat.py
 create mode 100644 pilot/scene/chat_db/out_parser.py
 create mode 100644 pilot/scene/chat_db/prompt.py
 create mode 100644 pilot/scene/chat_execution/__init__.py
 create mode 100644 pilot/scene/chat_execution/chat.py
 create mode 100644 pilot/scene/chat_factory.py
 create mode 100644 pilot/scene/chat_knowledge/__init__.py
 create mode 100644 pilot/scene/chat_normal/__init__.py
 create mode 100644 pilot/scene/message.py

diff --git a/.env.template b/.env.template
index d809a362b..8fd0b51b2 100644
--- a/.env.template
+++ b/.env.template
@@ -18,7 +18,7 @@
 #** LLM MODELS **#
 #*******************************************************************#
 LLM_MODEL=vicuna-13b
-MODEL_SERVER=http://your_model_server_url
+MODEL_SERVER=http://120.79.27.110:8000
 LIMIT_MODEL_CONCURRENCY=5
 MAX_POSITION_EMBEDDINGS=4096
 
diff --git a/.gitignore b/.gitignore
index cb21ee557..06a210cd9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,8 @@ __pycache__/
 # C extensions
 *.so
 
+message/
+
 .env
 .idea
 .vscode
diff --git a/pilot/common/formatting.py b/pilot/common/formatting.py
new file mode 100644
index 000000000..3b3b597b0
--- /dev/null
+++ b/pilot/common/formatting.py
@@ -0,0 +1,38 @@
+"""Utilities for formatting strings."""
+from string import Formatter
+from typing import Any, List, Mapping, Sequence, Union
+
+
+class StrictFormatter(Formatter):
+    """A subclass of formatter that checks for extra keys."""
+
+    def check_unused_args(
+        self,
+        used_args: Sequence[Union[int, str]],
+        args: Sequence,
+        kwargs: Mapping[str, Any],
+    ) -> None:
+        """Check to see if extra parameters are passed."""
+        extra = set(kwargs).difference(used_args)
+        if extra:
+            raise KeyError(extra)
+
+    def vformat(
+        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
+    ) -> str:
+        """Check that no arguments are provided."""
+        if len(args) > 0:
+            raise ValueError(
+                "No arguments should be provided, "
+                "everything should be passed as keyword arguments."
+            )
+        return super().vformat(format_string, args, kwargs)
+
+    def validate_input_variables(
+        self, format_string: str, input_variables: List[str]
+    ) -> None:
+        dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
+        super().format(format_string, **dummy_inputs)
+
+
+formatter = StrictFormatter()
diff --git a/pilot/common/markdown_text.py b/pilot/common/markdown_text.py
new file mode 100644
index 000000000..b0f96e15c
--- /dev/null
+++ b/pilot/common/markdown_text.py
@@ -0,0 +1,48 @@
+import markdown2
+import pandas as pd
+
+
+def datas_to_table_html(data):
+    df = pd.DataFrame(data)
+    table_style = """\n\n"""
+    html_table = df.to_html(index=False, header=False, border=True)
+    return table_style + html_table
+
+
+def generate_markdown_table(data):
+    """Generate a Markdown table.
+
+    data: a two-dimensional list containing the table header and the table content.
+    """
+    # Get the number of columns
+    num_cols = len(data[0])
+    # Build the header row
+    header = "| "
+    for i in range(num_cols):
+        header += data[0][i] + " | "
+
+    # Build the separator row
+    separator = "| "
+    for i in range(num_cols):
+        separator += "--- | "
+
+    # Build the table body
+    content = ""
+    for row in data[1:]:
+        content += "| "
+        for i in range(num_cols):
+            content += str(row[i]) + " | "
+        content += "\n"
+
+    # Combine the header, separator, and body
+    table = header + "\n" + separator + "\n" + content
+
+    return table
+
+
+def generate_htm_table(data):
+    markdown_text = generate_markdown_table(data)
+    html_table = markdown2.markdown(markdown_text, extras=["tables"])
+    return html_table
+
+
+if __name__ == "__main__":
+    mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | 
None |\n"
+
+    print(generate_htm_table(mk_text))
\ No newline at end of file
diff --git a/pilot/common/schema.py b/pilot/common/schema.py
new file mode 100644
index 000000000..f66bba1a6
--- /dev/null
+++ b/pilot/common/schema.py
@@ -0,0 +1,8 @@
+from enum import auto, Enum
+from typing import List, Any
+
+
+class SeparatorStyle(Enum):
+    SINGLE = "###"
+    TWO = ""
+    THREE = auto()
+    FOUR = auto()
diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py
new file mode 100644
index 000000000..2c16869d5
--- /dev/null
+++ b/pilot/common/sql_database.py
@@ -0,0 +1,308 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any, Iterable, List, Optional
+from pydantic import BaseModel, Field, root_validator, validator, Extra
+from abc import ABC, abstractmethod
+import sqlalchemy
+from sqlalchemy import (
+    MetaData,
+    Table,
+    create_engine,
+    inspect,
+    select,
+    text,
+)
+from sqlalchemy.engine import CursorResult, Engine
+from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
+from sqlalchemy.schema import CreateTable
+from sqlalchemy.orm import sessionmaker, scoped_session
+
+
+def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
+    return (
+        f'Name: {index["name"]}, Unique: {index["unique"]},'
+        f' Columns: {str(index["column_names"])}'
+    )
+
+
+class Database:
+    """SQLAlchemy wrapper around a database."""
+
+    def __init__(
+        self,
+        engine,
+        schema: Optional[str] = None,
+        metadata: Optional[MetaData] = None,
+        ignore_tables: Optional[List[str]] = None,
+        include_tables: Optional[List[str]] = None,
+        sample_rows_in_table_info: int = 3,
+        indexes_in_table_info: bool = False,
+        custom_table_info: Optional[dict] = None,
+        view_support: bool = False,
+    ):
+        """Create engine from database URI."""
+        self._engine = engine
+        self._schema = schema
+        if include_tables and ignore_tables:
+            raise ValueError("Cannot specify both include_tables and ignore_tables")
+
+        self._inspector = inspect(self._engine)
+        session_factory = sessionmaker(bind=engine)
+        Session = scoped_session(session_factory)
+
+        self._db_sessions = Session
+
+        self._all_tables = set()
+        self.view_support = view_support
+        self._usable_tables = set()
+        self._include_tables = set(include_tables) if include_tables else set()
+        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
+        self._custom_table_info = custom_table_info
+        self._indexes_in_table_info = indexes_in_table_info
+        self._sample_rows_in_table_info = sample_rows_in_table_info
+        # including view support by adding the views as well as tables to the all
+        # tables list if view_support is True
+        # self._all_tables = set(
+        #     self._inspector.get_table_names(schema=schema)
+        #     + (self._inspector.get_view_names(schema=schema) if view_support else [])
+        # )
+
+        # self._include_tables = set(include_tables) if include_tables else set()
+        # if self._include_tables:
+        #     missing_tables = self._include_tables - self._all_tables
+        #     if missing_tables:
+        #         raise ValueError(
+        #             f"include_tables {missing_tables} not found in database"
+        #         )
+        # self._ignore_tables = set(ignore_tables) if ignore_tables else set()
+        # if self._ignore_tables:
+        #     missing_tables = self._ignore_tables - self._all_tables
+        #     if missing_tables:
+        #         raise ValueError(
+        #             f"ignore_tables {missing_tables} not found in database"
+        #         )
+        # usable_tables = self.get_usable_table_names()
+        # self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
+
+        # if not isinstance(sample_rows_in_table_info, int):
+        #     raise TypeError("sample_rows_in_table_info must be an integer")
+        #
+        # 
self._sample_rows_in_table_info = sample_rows_in_table_info
+        # self._indexes_in_table_info = indexes_in_table_info
+        #
+        # self._custom_table_info = custom_table_info
+        # if self._custom_table_info:
+        #     if not isinstance(self._custom_table_info, dict):
+        #         raise TypeError(
+        #             "table_info must be a dictionary with table names as keys and the "
+        #             "desired table info as values"
+        #         )
+        #     # only keep the tables that are also present in the database
+        #     intersection = set(self._custom_table_info).intersection(self._all_tables)
+        #     self._custom_table_info = dict(
+        #         (table, self._custom_table_info[table])
+        #         for table in self._custom_table_info
+        #         if table in intersection
+        #     )
+
+        # self._metadata = metadata or MetaData()
+        # # # including view support if view_support = true
+        # self._metadata.reflect(
+        #     views=view_support,
+        #     bind=self._engine,
+        #     only=list(self._usable_tables),
+        #     schema=self._schema,
+        # )
+
+    @classmethod
+    def from_uri(
+        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> Database:
+        """Construct a SQLAlchemy engine from URI."""
+        _engine_args = engine_args or {}
+        return cls(create_engine(database_uri, **_engine_args), **kwargs)
+
+    @property
+    def dialect(self) -> str:
+        """Return string representation of dialect to use."""
+        return self._engine.dialect.name
+
+    def get_usable_table_names(self) -> Iterable[str]:
+        """Get names of tables available."""
+        if self._include_tables:
+            return self._include_tables
+        return self._all_tables - self._ignore_tables
+
+    def get_table_names(self) -> Iterable[str]:
+        """Get names of tables available."""
+        warnings.warn(
+            "This method is deprecated - please use `get_usable_table_names`."
+        )
+        return self.get_usable_table_names()
+
+    def get_session(self, db_name: str):
+        session = self._db_sessions()
+
+        self._metadata = MetaData()
+        # sql = f"use {db_name}"
+        sql = text(f'use `{db_name}`')
+        session.execute(sql)
+
+        # Process the table schema information for this database
+
+        self._metadata.reflect(bind=self._engine, schema=db_name)
+
+        # including view support by adding the views as well as tables to the all
+        # tables list if view_support is True
+        self._all_tables = set(
+            self._inspector.get_table_names(schema=db_name)
+            + (self._inspector.get_view_names(schema=db_name) if self.view_support else [])
+        )
+
+        return session
+
+    def get_current_db_name(self, session) -> str:
+        return session.execute(text('SELECT DATABASE()')).scalar()
+
+    def table_simple_info(self, session):
+        _sql = f"""
+            select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
+        """
+        cursor = session.execute(text(_sql))
+        results = cursor.fetchall()
+        return results
+
+    @property
+    def table_info(self) -> str:
+        """Information about all tables in the database."""
+        return self.get_table_info()
+
+    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
+        """Get information about specified tables.
+
+        Follows best practices as specified in: Rajkumar et al, 2022
+        (https://arxiv.org/abs/2204.00498)
+
+        If `sample_rows_in_table_info`, the specified number of sample rows will be
+        appended to each table description. This can increase performance as
+        demonstrated in the paper. 
+ """ + all_table_names = self.get_usable_table_names() + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + raise ValueError(f"table_names {missing_tables} not found in database") + all_table_names = table_names + + meta_tables = [ + tbl + for tbl in self._metadata.sorted_tables + if tbl.name in set(all_table_names) + and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) + ] + + tables = [] + for table in meta_tables: + if self._custom_table_info and table.name in self._custom_table_info: + tables.append(self._custom_table_info[table.name]) + continue + + # add create table command + create_table = str(CreateTable(table).compile(self._engine)) + table_info = f"{create_table.rstrip()}" + has_extra_info = ( + self._indexes_in_table_info or self._sample_rows_in_table_info + ) + if has_extra_info: + table_info += "\n\n/*" + if self._indexes_in_table_info: + table_info += f"\n{self._get_table_indexes(table)}\n" + if self._sample_rows_in_table_info: + table_info += f"\n{self._get_sample_rows(table)}\n" + if has_extra_info: + table_info += "*/" + tables.append(table_info) + final_str = "\n\n".join(tables) + return final_str + + def _get_sample_rows(self, table: Table) -> str: + # build the select command + command = select(table).limit(self._sample_rows_in_table_info) + + # save the columns in string format + columns_str = "\t".join([col.name for col in table.columns]) + + try: + # get the sample rows + with self._engine.connect() as connection: + sample_rows_result: CursorResult = connection.execute(command) + # shorten values in the sample rows + sample_rows = list( + map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) + ) + + # save the sample rows in string format + sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) + + # in some dialects when there are no rows in the table a + # 'ProgrammingError' is returned + except ProgrammingError: + sample_rows_str = "" + + return ( + f"{self._sample_rows_in_table_info} rows from {table.name} table:\n" + f"{columns_str}\n" + f"{sample_rows_str}" + ) + + def _get_table_indexes(self, table: Table) -> str: + indexes = self._inspector.get_indexes(table.name) + indexes_formatted = "\n".join(map(_format_index, indexes)) + return f"Table Indexes:\n{indexes_formatted}" + + def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: + """Get information about specified tables.""" + try: + return self.get_table_info(table_names) + except ValueError as e: + """Format the error message""" + return f"Error: {e}" + + def run(self, session, command: str, fetch: str = "all") -> List: + """Execute a SQL command and return a string representing the results.""" + cursor = session.execute(text(command)) + if cursor.returns_rows: + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone()[0] # type: ignore + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") + field_names = tuple(i[0:] for i in cursor.keys()) + + result = list(result) + result.insert(0, field_names) + return result + + def run_no_throw(self, session, command: str, fetch: str = "all") -> List: + """Execute a SQL command and return a string representing the results. + + If the statement returns rows, a string of the results is returned. + If the statement returns no rows, an empty string is returned. + + If the statement throws an error, the error message is returned. 
+ """ + try: + return self.run(session, command, fetch) + except SQLAlchemyError as e: + """Format the error message""" + return f"Error: {e}" + + def get_database_list(self): + session = self._db_sessions() + cursor = session.execute(text(' show databases;')) + results = cursor.fetchall() + return [d[0] for d in results if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]] diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 9023bc061..a246cd43b 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -7,6 +7,7 @@ from typing import List from auto_gpt_plugin_template import AutoGPTPluginTemplate from pilot.singleton import Singleton +from pilot.common.sql_database import Database class Config(metaclass=Singleton): @@ -62,7 +63,7 @@ class Config(metaclass=Singleton): ) self.speak_mode = False - + self.prompt_templates = {} ### Related configuration of built-in commands self.command_registry = [] @@ -75,7 +76,8 @@ class Config(metaclass=Singleton): self.execute_local_commands = ( os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" ) - + ### message stor file + self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message") ### The associated configuration parameters of the plug-in control the loading and use of the plug-in self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins") @@ -101,6 +103,10 @@ class Config(metaclass=Singleton): self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root") self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456") + ### TODO Adapt to multiple types of libraries + self.local_db = Database.from_uri("mysql+pymysql://" + self.LOCAL_DB_USER +":"+ self.LOCAL_DB_PASSWORD +"@" +self.LOCAL_DB_HOST + ":" + str(self.LOCAL_DB_PORT) , + engine_args ={"pool_size": 10, "pool_recycle": 3600, "echo": True}) + ### LLM Model Service Configuration self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b") self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5)) diff --git a/pilot/connections/mysql.py b/pilot/connections/mysql.py index 83da27ec3..fb2d34f1e 100644 --- a/pilot/connections/mysql.py +++ b/pilot/connections/mysql.py @@ -31,7 +31,19 @@ class MySQLOperator: cursor.execute(_sql) results = cursor.fetchall() return results - + + def run_sql(self, db_name:str, sql:str, fetch: str = "all"): + with self.conn.cursor() as cursor: + cursor.execute("USE " + db_name) + cursor.execute(sql) + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone()[0] # type: ignore + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") + return str(result) + def get_index(self, schema_name): pass diff --git a/pilot/conversation.py b/pilot/conversation.py index 7054fb453..55f19e3d7 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -2,6 +2,7 @@ # -*- coding:utf-8 -*- import dataclasses +import uuid from enum import auto, Enum from typing import List, Any from pilot.configs.config import Config @@ -36,7 +37,7 @@ class Conversation: # Used for gradio server skip_next: bool = False conv_id: Any = None - + last_user_input: Any = None def get_prompt(self): if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep @@ -258,6 +259,7 @@ conversation_types = { "native": "LLM原生对话", "default_knownledge": "默认知识库对话", "custome": "新增知识库对话", + "auto_execute_plugin": "对话使用插件", } conv_templates = { diff --git a/pilot/inference_parsers/__init__.py b/pilot/inference_parsers/__init__.py new file mode 100644 index 000000000..e69de29bb diff 
--git a/pilot/inference_parsers/base.py b/pilot/inference_parsers/base.py
new file mode 100644
index 000000000..f317d876a
--- /dev/null
+++ b/pilot/inference_parsers/base.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, TypeVar
+
+from pydantic import BaseModel, Extra, Field, root_validator
+
+from pilot.prompts.base import PromptValue
+
+T = TypeVar("T")
+
+
+class BaseOutputParser(BaseModel, ABC, Generic[T]):
+    """Class to parse the output of an LLM call.
+
+    Output parsers help structure language model responses.
+    """
+
+    @abstractmethod
+    def parse(self, text: str) -> T:
+        """Parse the output of an LLM call.
+
+        A method which takes in a string (assumed output of a language model)
+        and parses it into some structure.
+
+        Args:
+            text: output of language model
+
+        Returns:
+            structured output
+        """
+
+    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
+        """Optional method to parse the output of an LLM call with a prompt.
+
+        The prompt is largely provided in the event the OutputParser wants
+        to retry or fix the output in some way, and needs information from
+        the prompt to do so.
+
+        Args:
+            completion: output of language model
+            prompt: prompt value
+
+        Returns:
+            structured output
+        """
+        return self.parse(completion)
+
+    def get_format_instructions(self) -> str:
+        """Instructions on how the LLM output should be formatted."""
+        raise NotImplementedError
+
+    @property
+    def _type(self) -> str:
+        """Return the type key."""
+        raise NotImplementedError(
+            f"_type property is not implemented in class {self.__class__.__name__}."
+            " This is required for serialization."
+        )
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of output parser."""
+        output_parser_dict = super().dict()
+        output_parser_dict["_type"] = self._type
+        return output_parser_dict
+
+
+class OutputParserException(Exception):
+    """Exception that output parsers should raise to signify a parsing error.
+
+    This exists to differentiate parsing errors from other code or execution errors
+    that also may arise inside the output parser. OutputParserExceptions will be
+    available to catch and handle in ways to fix the parsing error, while other
+    errors will be raised. 
+ """ + + pass diff --git a/pilot/memory/__init__.py b/pilot/memory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/memory/chat_history/__init__.py b/pilot/memory/chat_history/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py new file mode 100644 index 000000000..b71ad4a5f --- /dev/null +++ b/pilot/memory/chat_history/base.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field, root_validator, validator,Extra +from abc import ABC, abstractmethod +from typing import ( + Any, + Dict, + Generic, + List, + NamedTuple, + Optional, + Sequence, + TypeVar, + Union, +) + +from pilot.scene.message import OnceConversation + + + + + +class BaseChatHistoryMemory(ABC): + + def __init__(self): + self.conversations:List[OnceConversation] = [] + + @abstractmethod + def messages(self) -> List[OnceConversation]: # type: ignore + """Retrieve the messages from the local file""" + + @abstractmethod + def append(self, message: OnceConversation) -> None: + """Append the message to the record in the local file""" + + + @abstractmethod + def clear(self) -> None: + """Clear session memory from the local file""" + diff --git a/pilot/memory/chat_history/file_history.py b/pilot/memory/chat_history/file_history.py new file mode 100644 index 000000000..a3d53415b --- /dev/null +++ b/pilot/memory/chat_history/file_history.py @@ -0,0 +1,45 @@ +from typing import List +import json +import os +import datetime +from pilot.memory.chat_history.base import BaseChatHistoryMemory +from pathlib import Path + +from pilot.configs.config import Config +from pilot.scene.message import OnceConversation, conversation_from_dict,conversations_to_dict + + +CFG = Config() + + +class FileHistoryMemory(BaseChatHistoryMemory): + def __init__(self, chat_session_id:str): + now = datetime.datetime.now() + date_string = now.strftime("%Y%m%d") + path: str = f"{CFG.message_dir}/{date_string}" + os.makedirs(path, exist_ok=True) + + dir_path = Path(path) + self.file_path = Path(dir_path / f"{chat_session_id}.json") + if not self.file_path.exists(): + self.file_path.touch() + self.file_path.write_text(json.dumps([])) + + + + def messages(self) -> List[OnceConversation]: + items = json.loads(self.file_path.read_text()) + history:List[OnceConversation] = [] + for onece in items: + messages = conversation_from_dict(onece) + history.append(messages) + return history + + def append(self, once_message: OnceConversation) -> None: + historys = self.messages() + historys.append(once_message) + self.file_path.write_text(json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4), encoding="UTF-8") + + def clear(self) -> None: + self.file_path.write_text(json.dumps([])) + diff --git a/pilot/out_parser/__init__.py b/pilot/out_parser/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py new file mode 100644 index 000000000..1c80d4a08 --- /dev/null +++ b/pilot/out_parser/base.py @@ -0,0 +1,101 @@ +from __future__ import annotations +import json + +from abc import ABC, abstractmethod +from typing import ( + Any, + Dict, + Generic, + List, + NamedTuple, + Optional, + Sequence, + TypeVar, + Union, +) + +from pydantic import BaseModel, Extra, Field, root_validator + +from pilot.prompts.base import PromptValue + +T = TypeVar("T") + + +class BaseOutputParser(BaseModel, ABC, Generic[T]): + """Class to parse the output of an LLM 
call.
+
+    Output parsers help structure language model responses.
+    """
+
+    def parse_model_nostream_resp(self, response, sep: str):
+        text = response.text.strip()
+        text = text.rstrip()
+        respObj = json.loads(text)
+
+        xx = respObj['response']
+        xx = xx.strip(b'\x00'.decode())
+        respObj_ex = json.loads(xx)
+        if respObj_ex['error_code'] == 0:
+            all_text = respObj_ex['text']
+            ### Parse the returned text and extract the AI reply section
+            tmpResp = all_text.split(sep)
+            last_index = -1
+            for i in range(len(tmpResp)):
+                if tmpResp[i].find('ASSISTANT:') != -1:
+                    last_index = i
+            ai_response = tmpResp[last_index]
+            ai_response = ai_response.replace("ASSISTANT:", "")
+            ai_response = ai_response.replace("\n", "")
+            ai_response = ai_response.replace("\_", "_")
+            print(f"un_stream clear response:{ai_response}")
+            return ai_response
+        else:
+            raise ValueError(f"Model server error! code={respObj_ex['error_code']}")
+
+    @abstractmethod
+    def parse(self, text: str) -> T:
+        """Parse the output of an LLM call.
+
+        A method which takes in a string (assumed output of a language model)
+        and parses it into some structure.
+
+        Args:
+            text: output of language model
+
+        Returns:
+            structured output
+        """
+
+    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
+        """Optional method to parse the output of an LLM call with a prompt.
+
+        The prompt is largely provided in the event the OutputParser wants
+        to retry or fix the output in some way, and needs information from
+        the prompt to do so.
+
+        Args:
+            completion: output of language model
+            prompt: prompt value
+
+        Returns:
+            structured output
+        """
+        return self.parse(completion)
+
+    def get_format_instructions(self) -> str:
+        """Instructions on how the LLM output should be formatted."""
+        raise NotImplementedError
+
+    @property
+    def _type(self) -> str:
+        """Return the type key."""
+        raise NotImplementedError(
+            f"_type property is not implemented in class {self.__class__.__name__}."
+            " This is required for serialization."
+ ) + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of output parser.""" + output_parser_dict = super().dict() + output_parser_dict["_type"] = self._type + return output_parser_dict diff --git a/pilot/prompts/base.py b/pilot/prompts/base.py new file mode 100644 index 000000000..bd082000e --- /dev/null +++ b/pilot/prompts/base.py @@ -0,0 +1,51 @@ + + +import json +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union + +import yaml +from pydantic import BaseModel, Extra, Field, root_validator + +from pilot.scene.base_message import BaseMessage,HumanMessage,AIMessage, SystemMessage + + +def get_buffer_string( + messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" +) -> str: + """Get buffer string of messages.""" + string_messages = [] + for m in messages: + if isinstance(m, HumanMessage): + role = human_prefix + elif isinstance(m, AIMessage): + role = ai_prefix + elif isinstance(m, SystemMessage): + role = "System" + else: + raise ValueError(f"Got unsupported message type: {m}") + string_messages.append(f"{role}: {m.content}") + return "\n".join(string_messages) + + + +class PromptValue(BaseModel, ABC): + @abstractmethod + def to_string(self) -> str: + """Return prompt as string.""" + + @abstractmethod + def to_messages(self) -> List[BaseMessage]: + """Return prompt as messages.""" + +class ChatPromptValue(PromptValue): + messages: List[BaseMessage] + + def to_string(self) -> str: + """Return prompt as string.""" + return get_buffer_string(self.messages) + + def to_messages(self) -> List[BaseMessage]: + """Return prompt as messages.""" + return self.messages \ No newline at end of file diff --git a/pilot/prompts/generator_new.py b/pilot/prompts/generator_new.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/prompts/prompt_generator.py b/pilot/prompts/prompt_generator.py new file mode 100644 index 000000000..e0ffed4a6 --- /dev/null +++ b/pilot/prompts/prompt_generator.py @@ -0,0 +1,54 @@ +from typing import Any, Callable, Dict, List, Optional + + +class PromptGenerator: + """ + generating custom prompt strings based on constraints; + Compatible with AutoGpt Plugin; + """ + + def __init__(self) -> None: + """ + Initialize the PromptGenerator object with empty lists of constraints, + commands, resources, and performance evaluations. + """ + self.constraints = [] + self.commands = [] + self.resources = [] + self.performance_evaluation = [] + self.goals = [] + self.command_registry = None + self.name = "Bob" + self.role = "AI" + self.response_format = None + + + + def add_command( + self, + command_label: str, + command_name: str, + args=None, + function: Optional[Callable] = None, + ) -> None: + """ + Add a command to the commands list with a label, name, and optional arguments. + GB-GPT and Auto-GPT plugin registration command. + Args: + command_label (str): The label of the command. + command_name (str): The name of the command. + args (dict, optional): A dictionary containing argument names and their + values. Defaults to None. + function (callable, optional): A callable function to be called when + the command is executed. Defaults to None. 
+ """ + if args is None: + args = {} + command_args = {arg_key: arg_value for arg_key, arg_value in args.items()} + command = { + "label": command_label, + "name": command_name, + "args": command_args, + "function": function, + } + self.commands.append(command) \ No newline at end of file diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py new file mode 100644 index 000000000..8dcdcdf51 --- /dev/null +++ b/pilot/prompts/prompt_new.py @@ -0,0 +1,104 @@ +import json +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union +from pydantic import BaseModel, Extra, Field, root_validator + + +from pilot.common.formatting import formatter +from pilot.out_parser.base import BaseOutputParser + + +def jinja2_formatter(template: str, **kwargs: Any) -> str: + """Format a template using jinja2.""" + try: + from jinja2 import Template + except ImportError: + raise ImportError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + ) + + return Template(template).render(**kwargs) + + +DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { + "f-string": formatter.format, + "jinja2": jinja2_formatter, +} + + +class PromptTemplate(BaseModel, ABC): + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + template_scene: str + template: str + """The prompt template.""" + template_format: str = "f-string" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + response_format:str + output_parser: BaseOutputParser = None + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "prompt" + + def _generate_command_string(self, command: Dict[str, Any]) -> str: + """ + Generate a formatted string representation of a command. + + Args: + command (dict): A dictionary containing command information. + + Returns: + str: The formatted command string. + """ + args_string = ", ".join( + f'"{key}": "{value}"' for key, value in command["args"].items() + ) + return f'{command["label"]}: "{command["name"]}", args: {args_string}' + + def _generate_numbered_list(self, items: List[Any], item_type="list") -> str: + """ + Generate a numbered list from given items based on the item_type. + + Args: + items (list): A list of items to be numbered. + item_type (str, optional): The type of items in the list. + Defaults to 'list'. + + Returns: + str: The formatted numbered list. + """ + if item_type == "command": + command_strings = [] + if self.command_registry: + command_strings += [ + str(item) + for item in self.command_registry.commands.values() + if item.enabled + ] + # terminate command is added manually + command_strings += [self._generate_command_string(item) for item in items] + return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) + else: + return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) + + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + """ + + kwargs["response"] = json.dumps(self.response_format, indent=4) + return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + + def add_goals(self, goal: str) -> None: + self.goals.append(goal) + + def add_constraint(self, constraint: str) -> None: + """ + Add a constraint to the constraints list. + + Args: + constraint (str): The constraint to be added. 
+ """ + self.constraints.append(constraint) diff --git a/pilot/prompts/prompt_template.py b/pilot/prompts/prompt_template.py new file mode 100644 index 000000000..ad597c33d --- /dev/null +++ b/pilot/prompts/prompt_template.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import json +import yaml +from string import Formatter +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union + +from pydantic import BaseModel, Extra, Field, root_validator + +from pilot.out_parser.base import BaseOutputParser +from pilot.prompts.base import PromptValue +from pilot.scene.base_message import HumanMessage, AIMessage, SystemMessage, BaseMessage +from pilot.common.formatting import formatter + + +def jinja2_formatter(template: str, **kwargs: Any) -> str: + """Format a template using jinja2.""" + try: + from jinja2 import Template + except ImportError: + raise ImportError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + ) + + return Template(template).render(**kwargs) + + +def validate_jinja2(template: str, input_variables: List[str]) -> None: + input_variables_set = set(input_variables) + valid_variables = _get_jinja2_variables_from_template(template) + missing_variables = valid_variables - input_variables_set + extra_variables = input_variables_set - valid_variables + + error_message = "" + if missing_variables: + error_message += f"Missing variables: {missing_variables} " + + if extra_variables: + error_message += f"Extra variables: {extra_variables}" + + if error_message: + raise KeyError(error_message.strip()) + + +def _get_jinja2_variables_from_template(template: str) -> Set[str]: + try: + from jinja2 import Environment, meta + except ImportError: + raise ImportError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + ) + env = Environment() + ast = env.parse(template) + variables = meta.find_undeclared_variables(ast) + return variables + + +DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { + "f-string": formatter.format, + "jinja2": jinja2_formatter, +} + +DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { + "f-string": formatter.validate_input_variables, + "jinja2": validate_jinja2, +} + + +def check_valid_template( + template: str, template_format: str, input_variables: List[str] +) -> None: + """Check that template string is valid.""" + if template_format not in DEFAULT_FORMATTER_MAPPING: + valid_formats = list(DEFAULT_FORMATTER_MAPPING) + raise ValueError( + f"Invalid template format. Got `{template_format}`;" + f" should be one of {valid_formats}" + ) + try: + validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] + validator_func(template, input_variables) + except KeyError as e: + raise ValueError( + "Invalid prompt schema; check for mismatched or missing input parameters. 
" + + str(e) + ) + + +class BasePromptTemplate(BaseModel, ABC): + """Base class for all prompt templates, returning a prompt.""" + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + output_parser: Optional[BaseOutputParser] = None + """How to parse the output of calling an LLM on this formatted prompt.""" + partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field( + default_factory=dict + ) + + @abstractmethod + def format_prompt(self, **kwargs: Any) -> PromptValue: + """Create Chat Messages.""" + + @root_validator() + def validate_variable_names(cls, values: Dict) -> Dict: + """Validate variable names do not include restricted names.""" + if "stop" in values["input_variables"]: + raise ValueError( + "Cannot have an input variable named 'stop', as it is used internally," + " please rename." + ) + if "stop" in values["partial_variables"]: + raise ValueError( + "Cannot have an partial variable named 'stop', as it is used " + "internally, please rename." + ) + + overall = set(values["input_variables"]).intersection( + values["partial_variables"] + ) + if overall: + raise ValueError( + f"Found overlapping input and partial variables: {overall}" + ) + return values + + def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: + """Return a partial of the prompt template.""" + prompt_dict = self.__dict__.copy() + prompt_dict["input_variables"] = list( + set(self.input_variables).difference(kwargs) + ) + prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} + return type(self)(**prompt_dict) + + def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: + # Get partial params: + partial_kwargs = { + k: v if isinstance(v, str) else v() + for k, v in self.partial_variables.items() + } + return {**partial_kwargs, **kwargs} + + @abstractmethod + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of prompt.""" + prompt_dict = super().dict(**kwargs) + prompt_dict["_type"] = self._prompt_type + return prompt_dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the prompt. + + Args: + file_path: Path to directory to save prompt to. + + Example: + .. code-block:: python + + prompt.save(file_path="path/prompt.yaml") + """ + if self.partial_variables: + raise ValueError("Cannot save prompt with partial variables.") + # Convert file to Path object. 
+ if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + # Fetch dictionary to save + prompt_dict = self.dict() + + if save_path.suffix == ".json": + with open(file_path, "w") as f: + json.dump(prompt_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(file_path, "w") as f: + yaml.dump(prompt_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml") + +class StringPromptValue(PromptValue): + text: str + + def to_string(self) -> str: + """Return prompt as string.""" + return self.text + + def to_messages(self) -> List[BaseMessage]: + """Return prompt as messages.""" + return [HumanMessage(content=self.text)] + + + +class StringPromptTemplate(BasePromptTemplate, ABC): + """String prompt should expose the format method, returning a prompt.""" + + def format_prompt(self, **kwargs: Any) -> PromptValue: + """Create Chat Messages.""" + return StringPromptValue(text=self.format(**kwargs)) + + +class PromptTemplate(StringPromptTemplate): + """Schema to represent a prompt for an LLM. + + Example: + .. code-block:: python + + from langchain import PromptTemplate + prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") + """ + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + + template: str + """The prompt template.""" + + template_format: str = "f-string" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + + validate_template: bool = True + """Whether or not to try validating the template.""" + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "prompt" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + kwargs = self._merge_partial_and_user_variables(**kwargs) + return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that template and input variables are consistent.""" + if values["validate_template"]: + all_inputs = values["input_variables"] + list(values["partial_variables"]) + check_valid_template( + values["template"], values["template_format"], all_inputs + ) + return values + + @classmethod + def from_examples( + cls, + examples: List[str], + suffix: str, + input_variables: List[str], + example_separator: str = "\n\n", + prefix: str = "", + **kwargs: Any, + ) -> PromptTemplate: + """Take examples in list format with prefix and suffix to create a prompt. + + Intended to be used as a way to dynamically create a prompt from examples. + + Args: + examples: List of examples to use in the prompt. + suffix: String to go after the list of examples. Should generally + set up the user's input. + input_variables: A list of variable names the final prompt template + will expect. + example_separator: The separator to use in between examples. Defaults + to two new line characters. + prefix: String that should go before any examples. Generally includes + examples. Default to an empty string. + + Returns: + The final prompt generated. 
+ """ + template = example_separator.join([prefix, *examples, suffix]) + return cls(input_variables=input_variables, template=template, **kwargs) + + @classmethod + def from_file( + cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any + ) -> PromptTemplate: + """Load a prompt from a file. + + Args: + template_file: The path to the file containing the prompt template. + input_variables: A list of variable names the final prompt template + will expect. + Returns: + The prompt loaded from the file. + """ + with open(str(template_file), "r") as f: + template = f.read() + return cls(input_variables=input_variables, template=template, **kwargs) + + @classmethod + def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: + """Load a prompt template from a template.""" + if "template_format" in kwargs and kwargs["template_format"] == "jinja2": + # Get the variables for the template + input_variables = _get_jinja2_variables_from_template(template) + + else: + input_variables = { + v for _, v, _, _ in Formatter().parse(template) if v is not None + } + + if "partial_variables" in kwargs: + partial_variables = kwargs["partial_variables"] + input_variables = { + var for var in input_variables if var not in partial_variables + } + + return cls( + input_variables=list(sorted(input_variables)), template=template, **kwargs + ) + + +# For backwards compatibility. +Prompt = PromptTemplate \ No newline at end of file diff --git a/pilot/scene/__init__.py b/pilot/scene/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/base.py b/pilot/scene/base.py new file mode 100644 index 000000000..302510f2b --- /dev/null +++ b/pilot/scene/base.py @@ -0,0 +1,8 @@ +from enum import Enum + +class ChatScene(Enum): + ChatWithDb = "chat_with_db" + ChatExecution = "chat_execution" + ChatKnowledge = "chat_default_knowledge" + ChatNewKnowledge = "chat_new_knowledge" + ChatNormal = "chat_normal" \ No newline at end of file diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py new file mode 100644 index 000000000..c283df625 --- /dev/null +++ b/pilot/scene/base_chat.py @@ -0,0 +1,102 @@ +from abc import ABC, abstractmethod +from pydantic import BaseModel, Field, root_validator, validator, Extra +from typing import ( + Any, + Dict, + Generic, + List, + NamedTuple, + Optional, + Sequence, + TypeVar, + Union, +) +import requests +from urllib.parse import urljoin + +import pilot.configs.config +from pilot.scene.message import OnceConversation +from pilot.prompts.prompt_new import PromptTemplate +from pilot.memory.chat_history.base import BaseChatHistoryMemory +from pilot.memory.chat_history.file_history import FileHistoryMemory + +from pilot.configs.model_config import LOGDIR, DATASETS_DIR +from pilot.utils import ( + build_logger, + server_error_msg, +) +from pilot.common.schema import SeparatorStyle +from pilot.scene.base import ChatScene +from pilot.configs.config import Config + +logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") +headers = {"User-Agent": "dbgpt Client"} +CFG = Config() +class BaseChat(ABC): + chat_scene: str = None + memory: BaseChatHistoryMemory + llm_model: Any = None + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + temperature: float = 0.6 + max_new_tokens: int = 1024 + # By default, keep the last two rounds of conversation records as the context + chat_retention_rounds: int = 2 + sep = SeparatorStyle.SINGLE.value + class Config: + """Configuration for this pydantic object.""" + arbitrary_types_allowed = True + + def 
__init__(self, chat_mode, chat_session_id, current_user_input):
+        self.chat_session_id = chat_session_id
+        self.chat_mode = chat_mode
+        self.current_user_input: str = current_user_input
+        self.llm_model = CFG.LLM_MODEL
+        ### TODO
+        self.memory = FileHistoryMemory(chat_session_id)
+        ### load prompt template
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+        self.history_message: List[OnceConversation] = []
+        self.current_message: OnceConversation = OnceConversation()
+        self.current_tokens_used: int = 0
+        ### load this chat session's history
+        self.history_message = self._load_history(self.chat_session_id)
+
+    @property
+    def chat_type(self) -> str:
+        raise NotImplementedError("Not supported for this chat type.")
+
+    def call(self):
+        pass
+
+    def chat_show(self):
+        pass
+
+    def current_ai_response(self) -> str:
+        pass
+
+    def _load_history(self, session_id: str) -> List[OnceConversation]:
+        """
+        load chat history by session_id
+        Args:
+            session_id:
+        Returns:
+        """
+        return self.memory.messages()
+
+    def generate(self, p) -> str:
+        """
+        generate context for LLM input
+        Args:
+            p:
+
+        Returns:
+
+        """
+        pass
diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py
new file mode 100644
index 000000000..504b4e152
--- /dev/null
+++ b/pilot/scene/base_message.py
@@ -0,0 +1,141 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+)
+
+from pydantic import BaseModel, Extra, Field, root_validator
+
+
+class PromptValue(BaseModel, ABC):
+    @abstractmethod
+    def to_string(self) -> str:
+        """Return prompt as string."""
+
+    @abstractmethod
+    def to_messages(self) -> List[BaseMessage]:
+        """Return prompt as messages."""
+
+
+class BaseMessage(BaseModel):
+    """Message object."""
+
+    content: str
+    additional_kwargs: dict = Field(default_factory=dict)
+
+    @property
+    @abstractmethod
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+
+
+class HumanMessage(BaseMessage):
+    """Type of message that is spoken by the human."""
+
+    example: bool = False
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "human"
+
+
+class AIMessage(BaseMessage):
+    """Type of message that is spoken by the AI."""
+
+    example: bool = False
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "ai"
+
+
+class SystemMessage(BaseMessage):
+    """Type of message that is a system message."""
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "system"
+
+
+class Generation(BaseModel):
+    """Output of a single generation."""
+
+    text: str
+    """Generated text output."""
+
+    generation_info: Optional[Dict[str, Any]] = None
+    """Raw generation info response from the provider"""
+    """May include things like reason for finishing (e.g. 
in OpenAI)"""
+
+
+class ChatGeneration(Generation):
+    """Output of a single generation."""
+
+    text = ""
+    message: BaseMessage
+
+    @root_validator
+    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        values["text"] = values["message"].content
+        return values
+
+
+class ChatResult(BaseModel):
+    """Class that contains all relevant information for a Chat Result."""
+
+    generations: List[ChatGeneration]
+    """List of the things generated."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
+
+
+class LLMResult(BaseModel):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
+
+
+def _message_to_dict(message: BaseMessage) -> dict:
+    return {"type": message.type, "data": message.dict()}
+
+
+def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
+    return [_message_to_dict(m) for m in messages]
+
+
+def _message_from_dict(message: dict) -> BaseMessage:
+    _type = message["type"]
+    if _type == "human":
+        return HumanMessage(**message["data"])
+    elif _type == "ai":
+        return AIMessage(**message["data"])
+    elif _type == "system":
+        return SystemMessage(**message["data"])
+    else:
+        raise ValueError(f"Got unexpected type: {_type}")
+
+
+def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
+    return [_message_from_dict(m) for m in messages]
diff --git a/pilot/scene/chat_db/__init__.py b/pilot/scene/chat_db/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py
new file mode 100644
index 000000000..ec11c49d7
--- /dev/null
+++ b/pilot/scene/chat_db/chat.py
@@ -0,0 +1,223 @@
+import requests
+import datetime
+import threading
+import json
+from urllib.parse import urljoin
+from sqlalchemy import (
+    MetaData,
+    Table,
+    create_engine,
+    inspect,
+    select,
+    text,
+)
+from typing import Any, Iterable, List, Optional
+
+from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage
+from pilot.scene.base_chat import BaseChat, logger, headers
+from pilot.scene.base import ChatScene
+from pilot.common.sql_database import Database
+from pilot.configs.config import Config
+from pilot.scene.chat_db.out_parser import SqlAction
+from pilot.configs.model_config import LOGDIR, DATASETS_DIR
+from pilot.utils import (
+    build_logger,
+    server_error_msg,
+)
+from pilot.common.markdown_text import generate_markdown_table, generate_htm_table, datas_to_table_html
+from pilot.scene.chat_db.prompt import chat_db_prompt
+
+CFG = Config()
+
+
+class ChatWithDb(BaseChat):
+    chat_scene: str = ChatScene.ChatWithDb.value
+
+    """Number of results to return from the query"""
+
+    def __init__(self, chat_session_id, db_name, user_input):
+        """ """
+        super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
+        if not db_name:
+            raise ValueError(f"{ChatScene.ChatWithDb.value} mode should choose a db!")
+        self.db_name = db_name
+        self.database = CFG.local_db
+        # Prepare the DB info (get a connection to the specified database)
+        self.db_connect = self.database.get_session(self.db_name)
+        self.top_k: int = 5
+
+    def call(self) -> str:
+        input_values = {
+            "input": self.current_user_input,
+            "top_k": str(self.top_k),
+            "dialect": self.database.dialect,
+            "table_info": self.database.table_simple_info(self.db_connect),
+            # "stop": self.sep_style,
+        }
+
+        ### Chat sequence advance
+        self.current_message.chat_order = len(self.history_message) + 1
+        self.current_message.add_user_message(self.current_user_input)
+        self.current_message.start_date = datetime.datetime.now()
+        # TODO
+        self.current_message.tokens = 0
+
+        current_prompt = self.prompt_template.format(**input_values)
+
+        ### Build the current conversation. Should the prompt be built as in the first round? And how should a database switch be handled?
+        if self.history_message:
+            ## TODO: decide how to handle a database switch when history records already exist
+            logger.info(f"There are already {len(self.history_message)} rounds of conversations!")
+
+        self.current_message.add_system_message(current_prompt)
+
+        payload = {
+            "model": self.llm_model,
+            "prompt": self.generate_llm_text(),
+            "temperature": float(self.temperature),
+            "max_new_tokens": int(self.max_new_tokens),
+            "stop": self.sep_style.value,
+        }
+        logger.info(f"Request: \n{payload}")
+
+        try:
+            ### Call the non-streaming model service endpoint
+
+            # TODO - TEST
+            # response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120)
+            # clear_response = self.prompt_template.output_parser.parse_model_nostream_resp(response, self.sep_style)
+            # sql_action = self.prompt_template.output_parser.parse(clear_response)
+            resp_test = {
+                "SQL": "select * from users",
+                "thoughts": {
+                    "text": "thought",
+                    "reasoning": "reasoning",
+                    "plan": "- short bulleted\n- list that conveys\n- long-term plan",
+                    "criticism": "constructive self-criticism",
+                    "speak": "thoughts summary to say to user"
+                }
+            }
+
+            sql_action = SqlAction(**resp_test)
+
+            # self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
+
+            result = self.database.run(self.db_connect, sql_action.SQL)
+
+            self.current_message.add_ai_message(f"{datas_to_table_html(result)}")
+
+        except Exception as e:
+            logger.error("model response parse failed!" + str(e))
+            self.current_message.add_ai_message(str(e))
+        ### Persist the conversation record
+        self.memory.append(self.current_message)
+
+    def chat_show(self):
+        ret = [[None, None]]
+        # A single round of conversation has at most one user record and one AI record
+        # TODO: display the reasoning process in the front end
+        for message in self.current_message.messages:
+            if isinstance(message, HumanMessage):
+                ret[-1][0] = message.content
+            if isinstance(message, AIMessage):
+                ret[-1][1] = message.content
+        return ret
+
+    # Kept temporarily for front-end compatibility
+    def current_ai_response(self) -> str:
+        for message in self.current_message.messages:
+            if message.type == 'ai':
+                return message.content
+        return None
+
+    def generate_llm_text(self) -> str:
+        text = ""
+        ### Handle the history first
+        if len(self.history_message) > self.chat_retention_rounds:
+            ### Merge the first round and the most recent rounds of history into the conversation record
+            for first_message in self.history_message[0].messages:
+                text += first_message.type + ":" + first_message.content + self.sep
+
+            index = self.chat_retention_rounds - 1
+            for last_conversation in self.history_message[-index:]:
+                for last_message in last_conversation.messages:
+                    text += last_message.type + ":" + last_message.content + self.sep
+
+        else:
+            ### Concatenate the full history directly
+            for conversation in self.history_message:
+                for message in conversation.messages:
+                    text += message.type + ":" + message.content + self.sep
+
+        ### current conversation
+        for now_message in self.current_message.messages:
+            text += now_message.type + ":" + now_message.content + self.sep
+
+        return text
+
+    @property
+    def chat_type(self) -> str:
+        return ChatScene.ChatWithDb.value
+
+
+if __name__ == "__main__":
+    # chat: ChatWithDb = ChatWithDb("chat123", "gpt-user", "查询用户信息")
+    #
+    # chat.call()
+    #
+    # resp = chat.chat_show()
+    #
+    # print(vars(resp))
+
+    # memory = FileHistoryMemory("test123")
+    # once1 = OnceConversation()
+    # once1.add_user_message("问题测试")
+    # 
once1.add_system_message("prompt1") + # once1.add_system_message("prompt2") + # once1.chat_order = 1 + # once1.set_start_time(datetime.datetime.now()) + # memory.append(once1) + # + # once = OnceConversation() + # once.add_user_message("问题测试2") + # once.add_system_message("prompt3") + # once.add_system_message("prompt4") + # once.chat_order = 2 + # once.set_start_time(datetime.datetime.now()) + # memory.append(once) + + db: Database = CFG.local_db + db_connect = db.get_session("gpt-user") + data = db.run(db_connect, "select * from users") + print(generate_htm_table(data)) + + + # + # print(db.run(db_connect, "select * from users")) + # + # # + # # def print_numbers(): + # # db_connect1 = db.get_session("dbgpt-test") + # # cursor1 = db_connect1.execute(text("select * from test_name")) + # # if cursor1.returns_rows: + # # result1 = cursor1.fetchall() + # # print( result1) + # # + # # + # # # 创建线程 + # # t = threading.Thread(target=print_numbers) + # # # 启动线程 + # # t.start() + # + # print(db.run(db_connect, "select * from tran_order")) + # + # print(db.run(db_connect, "select count(*) as aa from tran_order")) + # + # print(db.table_simple_info(db_connect)) + # my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9] + # index = 3 + # last_three_elements = my_list[-index:] + # print(last_three_elements) diff --git a/pilot/scene/chat_db/out_parser.py b/pilot/scene/chat_db/out_parser.py new file mode 100644 index 000000000..fb1ae770c --- /dev/null +++ b/pilot/scene/chat_db/out_parser.py @@ -0,0 +1,45 @@ +import json +from abc import ABC, abstractmethod +from typing import ( + Any, + Dict, + Generic, + List, + NamedTuple, + Optional, + Sequence, + TypeVar, + Union, +) + +from pilot.out_parser.base import BaseOutputParser + + +class SqlAction(NamedTuple): + SQL: str + thoughts: Dict + + +class DbChatOutputParser(BaseOutputParser): + + def parse(self, text: str) -> SqlAction: + cleaned_output = text.rstrip() + if "```json" in cleaned_output: + _, cleaned_output = cleaned_output.split("```json") + if "```" in cleaned_output: + cleaned_output, _ = cleaned_output.split("```") + if cleaned_output.startswith("```json"): + cleaned_output = cleaned_output[len("```json"):] + if cleaned_output.startswith("```"): + cleaned_output = cleaned_output[len("```"):] + if cleaned_output.endswith("```"): + cleaned_output = cleaned_output[: -len("```")] + cleaned_output = cleaned_output.strip() + response = json.loads(cleaned_output) + sql, thoughts = response["SQL"], response["thoughts"] + + return SqlAction(sql, thoughts) + + @property + def _type(self) -> str: + return "sql_chat" diff --git a/pilot/scene/chat_db/prompt.py b/pilot/scene/chat_db/prompt.py new file mode 100644 index 000000000..5aefefe06 --- /dev/null +++ b/pilot/scene/chat_db/prompt.py @@ -0,0 +1,53 @@ +import json +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.scene.chat_db.out_parser import DbChatOutputParser + +CFG = Config() + +PROMPT_SCENE_DEFINE = """""" + +PROMPT_SUFFIX = """Only use the following tables: +{table_info} + +Question: {input} + +""" + +_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. +Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. +You can order the results by a relevant column to return the most interesting examples in the database. 
diff --git a/pilot/scene/chat_db/prompt.py b/pilot/scene/chat_db/prompt.py
new file mode 100644
index 000000000..5aefefe06
--- /dev/null
+++ b/pilot/scene/chat_db/prompt.py
@@ -0,0 +1,53 @@
+import json
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_db.out_parser import DbChatOutputParser
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """"""
+
+PROMPT_SUFFIX = """Only use the following tables:
+{table_info}
+
+Question: {input}
+
+"""
+
+_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
+Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
+You can order the results by a relevant column to return the most interesting examples in the database.
+Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
+Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
+
+"""
+
+PROMPT_RESPONSE = """You should only respond in JSON, in the following format:
+{response}
+
+Ensure the response can be parsed by Python json.loads
+"""
+
+RESPONSE_FORMAT = {
+    "thoughts": {
+        "reasoning": "reasoning",
+        "speak": "thoughts summary to say to user",
+    },
+    "SQL": "SQL Query to run"
+}
+
+chat_db_prompt = PromptTemplate(
+    template_scene=ChatScene.ChatWithDb.value,
+    input_variables=["input", "table_info", "dialect", "top_k", "response"],
+    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
+    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
+    output_parser=DbChatOutputParser()
+)
+
+CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})
+
+
+if __name__ == "__main__":
+    resp = chat_db_prompt.format(input="查询用户信息", table_info="user(a,b,c,d)", dialect="mysql", top_k=10)
+    print(resp)
diff --git a/pilot/scene/chat_execution/__init__.py b/pilot/scene/chat_execution/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py
new file mode 100644
index 000000000..5e85c4981
--- /dev/null
+++ b/pilot/scene/chat_execution/chat.py
@@ -0,0 +1,26 @@
+from typing import List
+
+from pilot.scene.base_chat import BaseChat, logger, headers
+from pilot.scene.message import OnceConversation
+from pilot.scene.base import ChatScene
+
+
+class ChatWithPlugin(BaseChat):
+    chat_scene: str = ChatScene.ChatExecution.value
+
+    def __init__(self, chat_mode, chat_session_id, current_user_input):
+        super().__init__(chat_mode, chat_session_id, current_user_input)
+
+    def call(self):
+        super().call()
+
+    def chat_show(self):
+        super().chat_show()
+
+    def _load_history(self, session_id: str) -> List[OnceConversation]:
+        return super()._load_history(session_id)
+
+    def generate(self, p) -> str:
+        return super().generate(p)
+
+    @property
+    def chat_type(self) -> str:
+        return ChatScene.ChatExecution.value
\ No newline at end of file
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
new file mode 100644
index 000000000..6544c6f0c
--- /dev/null
+++ b/pilot/scene/chat_factory.py
@@ -0,0 +1,20 @@
+from pilot.scene.base_chat import BaseChat
+from pilot.singleton import Singleton
+from pilot.scene.chat_db.chat import ChatWithDb
+from pilot.scene.chat_execution.chat import ChatWithPlugin
+
+
+class ChatFactory(metaclass=Singleton):
+
+    @staticmethod
+    def get_implementation(chat_mode, **kwargs):
+        chat_classes = BaseChat.__subclasses__()
+        implementation = None
+        for cls in chat_classes:
+            if cls.chat_scene == chat_mode:
+                implementation = cls(**kwargs)
+        if implementation is None:
+            raise Exception("Invalid implementation name:" + chat_mode)
+        return implementation
\ No newline at end of file
diff --git a/pilot/scene/chat_knowledge/__init__.py b/pilot/scene/chat_knowledge/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_normal/__init__.py b/pilot/scene/chat_normal/__init__.py
new file mode 100644
index 000000000..e69de29bb
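ChatFactory.get_implementation treats BaseChat.__subclasses__() as an implicit registry: importing ChatWithDb and ChatWithPlugin above is what makes them discoverable. A usage sketch (argument values are illustrative):

    chat = ChatFactory.get_implementation(
        ChatScene.ChatWithDb.value,
        chat_session_id="session-1",
        db_name="gpt-user",
        user_input="how many users are there?",
    )
    # returns a ChatWithDb instance; an unrecognized scene name raises Exception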
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
new file mode 100644
index 000000000..3a11f294f
--- /dev/null
+++ b/pilot/scene/message.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+from pydantic import BaseModel, Field, root_validator, validator
+from abc import ABC, abstractmethod
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    List,
+)
+
+from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, messages_to_dict, \
+    messages_from_dict
+
+
+class OnceConversation:
+    """
+    Holds all the information of one conversation. Currently kept in memory for a single
+    service; caching and database backends can be added to support distributed services.
+    """
+
+    def __init__(self):
+        self.messages: List[BaseMessage] = []
+        self.start_date: str = ""
+        self.chat_order: int = 0
+        self.cost: int = 0
+        self.tokens: int = 0
+
+    def add_user_message(self, message: str) -> None:
+        """Add a user message to the store"""
+        self.messages.append(HumanMessage(content=message))
+
+    def add_ai_message(self, message: str) -> None:
+        """Add an AI message to the store"""
+        self.messages.append(AIMessage(content=message))
+        """ """
+
+    def add_system_message(self, message: str) -> None:
+        """Add a system message to the store"""
+        self.messages.append(SystemMessage(content=message))
+
+    def set_start_time(self, dt: datetime):
+        dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
+        self.start_date = dt_str
+
+    def clear(self) -> None:
+        """Remove all messages from the store"""
+        self.messages.clear()
+        self.session_id = None
+
+
+def _conversation_to_dic(once: OnceConversation) -> dict:
+    start_str: str = ""
+    if once.start_date:
+        if isinstance(once.start_date, datetime):
+            start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
+        else:
+            start_str = once.start_date
+
+    return {
+        "chat_order": once.chat_order,
+        "start_date": start_str,
+        "cost": once.cost if once.cost else 0,
+        "tokens": once.tokens if once.tokens else 0,
+        "messages": messages_to_dict(once.messages)
+    }
+
+
+def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
+    return [_conversation_to_dic(m) for m in conversations]
+
+
+def conversation_from_dict(once: dict) -> OnceConversation:
+    conversation = OnceConversation()
+    conversation.cost = once.get('cost', 0)
+    conversation.tokens = once.get('tokens', 0)
+    conversation.start_date = once.get('start_date', '')
+    conversation.chat_order = int(once.get('chat_order'))
+    conversation.messages = messages_from_dict(once.get('messages', []))
+    return conversation
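The pair _conversation_to_dic / conversation_from_dict is intended to be symmetric, so one round of chat survives serialization. A minimal sketch (editorial, not part of the patch):

    once = OnceConversation()
    once.set_start_time(datetime.now())
    once.add_user_message("hello")
    once.add_ai_message("hi there")
    restored = conversation_from_dict(_conversation_to_dic(once))
    assert restored.messages[0].content == "hello"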
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 0f19bc354..fc8f6bec3 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -20,13 +20,14 @@ from pilot.connections.mysql import MySQLOperator
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
-from pilot.configs.model_config import LOGDIR, DATASETS_DIR
+from pilot.configs.model_config import LOGDIR, DATASETS_DIR
 
 from pilot.plugins import scan_plugins
 from pilot.configs.config import Config
 from pilot.commands.command_mange import CommandRegistry
 from pilot.prompts.auto_mode_prompt import AutoModePrompt
 from pilot.prompts.generator import PromptGenerator
+from pilot.scene.base_chat import BaseChat
 
 from pilot.commands.exception_not_commands import NotCommands
@@ -47,6 +48,8 @@ from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
 from pilot.commands.command import execute_ai_response_json
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_factory import ChatFactory
 
 logger = build_logger("webserver", LOGDIR + "webserver.log")
 headers = {"User-Agent": "dbgpt Client"}
@@ -68,7 +71,8 @@ priority = {
 }
 
 # Load plugins
-CFG= Config()
+CFG = Config()
+CHAT_FACTORY = ChatFactory()
 
 DB_SETTINGS = {
     "user": CFG.LOCAL_DB_USER,
@@ -125,6 +129,10 @@ def load_demo(url_params, request: gr.Request):
         gr.Dropdown.update(choices=dbs)
 
     state = default_conversation.copy()
+
+    unique_id = uuid.uuid1()
+    state.conv_id = str(unique_id)
+
    return (state,
            dropdown_update,
            gr.Chatbot.update(visible=True),
@@ -164,6 +172,8 @@ def add_text(state, text, request: gr.Request):
     state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
+    ### TODO
+    state.last_user_input = text
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
 
 
@@ -178,42 +188,52 @@ def post_process_code(code):
     return code
 
 
-def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
-    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        print("AUTO DB-GPT mode.")
-    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
-        print("Standard DB-GPT mode.")
-    print("Is this AUTO-GPT mode?", autogpt)
+def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
+    if mode == conversation_types["default_knownledge"] and not db_selector:
+        return ChatScene.ChatKnowledge
+    elif mode == conversation_types["custome"] and not db_selector:
+        return ChatScene.ChatNewKnowledge
+    elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
+        return ChatScene.ChatWithDb
+    elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
+        return ChatScene.ChatExecution
+    else:
+        return ChatScene.ChatNormal
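In short, the routing above says: the knowledge scenes apply only when no database is selected, a selected database plus auto-execute SQL mode routes to the new ChatWithDb scene, and everything else falls back to ChatNormal. For example (editorial sketch; the mode argument is deliberately left unset):

    scene = get_chat_mode(
        mode=None,
        sql_mode=conversation_sql_mode["auto_execute_ai_response"],
        db_selector="gpt-user",
    )
    assert scene == ChatScene.ChatWithDb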
+
+
+def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+    logger.info(f"User message sent! {state.conv_id}, {sql_mode}, {db_selector}")
     start_tstamp = time.time()
+    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
+    print(f"Current chat scene: {scene.value}")
     model_name = CFG.LLM_MODEL
 
-    dbname = db_selector
-    # TODO: splice the existing knowledge base into this request so answers are grounded in it; the prompt still needs tuning
-    if state.skip_next:
-        # This generate call is skipped due to invalid inputs
-        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-        return
+    if ChatScene.ChatWithDb == scene:
+        logger.info("Chat with DB goes through the new mode!")
+        chat_param = {
+            "chat_session_id": state.conv_id,
+            "db_name": db_selector,
+            "user_input": state.last_user_input
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+        chat.call()
+        # state.append_message(state.roles[1], chat.current_ai_response())
+        state.messages[-1][-1] = f"{chat.current_ai_response()}"
+        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
-    cfg = Config()
-    auto_prompt = AutoModePrompt()
-    auto_prompt.command_registry = cfg.command_registry
+    else:
+        dbname = db_selector
+        # TODO: splice the existing knowledge base into this request so answers are grounded in it; the prompt still needs tuning
+        if state.skip_next:
+            # This generate call is skipped due to invalid inputs
+            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+            return
+
+        if len(state.messages) == state.offset + 2:
+            query = state.messages[-2][1]
 
-    # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
-    if len(state.messages) == state.offset + 2:
-        query = state.messages[-2][1]
-        # The first round of conversation needs the prompt added
-        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-            # The first round in autogpt mode needs a dedicated prompt built
-            system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
-                                                               db_schemes=gen_sqlgen_conversation(dbname))
-            logger.info("[TEST]:" + system_prompt)
-            template_name = "auto_dbgpt_one_shot"
-            new_state = conv_templates[template_name].copy()
-            new_state.append_message(role='USER', message=system_prompt)
-            # new_state.append_message(new_state.roles[0], query)
-            new_state.append_message(new_state.roles[1], None)
-        else:
             template_name = "conv_one_shot"
             new_state = conv_templates[template_name].copy()
             # Context hints are added to the prompt so the chat uses existing knowledge. Should the context hint go only in the first round, or be added in every round?
@@ -225,99 +245,47 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
             new_state.append_message(new_state.roles[0], query)
             new_state.append_message(new_state.roles[1], None)
 
-        new_state.conv_id = uuid.uuid4().hex
-        state = new_state
-    else:
-        ### Follow-up conversation
-        query = state.messages[-2][1]
-        # The first round of conversation needs the prompt added
-        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-            ## Get the return value of the last plugin call
-            follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
-            state.messages[0][0] = ""
-            state.messages[0][1] = ""
-            state.messages[-2][1] = follow_up_prompt
 
-    prompt = state.get_prompt()
-    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-    if mode == conversation_types["default_knownledge"] and not db_selector:
-        query = state.messages[-2][1]
-        knqa = KnownLedgeBaseQA()
-        state.messages[-2][1] = knqa.get_similar_answer(query)
+            new_state.conv_id = uuid.uuid4().hex
+            state = new_state
+
+        prompt = state.get_prompt()
 
-        state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+        if mode == conversation_types["default_knownledge"] and not db_selector:
+            query = state.messages[-2][1]
+            knqa = KnownLedgeBaseQA()
+            state.messages[-2][1] = knqa.get_similar_answer(query)
+            prompt = state.get_prompt()
+            state.messages[-2][1] = query
+            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
-    if mode == conversation_types["custome"] and not db_selector:
-        persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
-        print("Vector store persistence path: ", persist_dir)
-        knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={"vector_store_name": vector_store_name["vs_name"],
-                                                                                                  "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
-        query = state.messages[-2][1]
-        docs = knowledge_embedding_client.similar_search(query, 1)
-        context = [d.page_content for d in docs]
-        prompt_template = PromptTemplate(
-            template=conv_qa_prompt_template,
-            input_variables=["context", "question"]
-        )
-        result = prompt_template.format(context="\n".join(context), question=query)
-        state.messages[-2][1] = result
-        prompt = state.get_prompt()
-        state.messages[-2][1] = query
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+        if mode == conversation_types["custome"] and not db_selector:
+            persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
+            print("Vector store persistence path: ", persist_dir)
+            knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={"vector_store_name": vector_store_name["vs_name"],
+                                                                                                      "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
+            query = state.messages[-2][1]
+            docs = knowledge_embedding_client.similar_search(query, 1)
+            context = [d.page_content for d in docs]
+            prompt_template = PromptTemplate(
+                template=conv_qa_prompt_template,
+                input_variables=["context", "question"]
+            )
+            result = prompt_template.format(context="\n".join(context), question=query)
+            state.messages[-2][1] = result
+            prompt = state.get_prompt()
+            state.messages[-2][1] = query
+            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
+        # Make requests
+        payload = {
+            "model": model_name,
+            "prompt": prompt,
+            "temperature": float(temperature),
+            "max_new_tokens": int(max_new_tokens),
+            "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
+        }
+        logger.info(f"Request: \n{payload}")
 
-    # Make requests
-    payload = {
-        "model": model_name,
-        "prompt": prompt,
-        "temperature": float(temperature),
-        "max_new_tokens": int(max_new_tokens),
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
-    }
-    logger.info(f"Request: \n{payload}")
-
-    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
-                                 headers=headers, json=payload, timeout=120)
-
-        print(response.json())
-        print(str(response))
-        try:
-            text = response.text.strip()
-            text = text.rstrip()
-            respObj = json.loads(text)
-
-            xx = respObj['response']
-            xx = xx.strip(b'\x00'.decode())
-            respObj_ex = json.loads(xx)
-            if respObj_ex['error_code'] == 0:
-                ai_response = None
-                all_text = respObj_ex['text']
-                ### Parse the returned text to extract the AI reply
-                tmpResp = all_text.split(state.sep)
-                last_index = -1
-                for i in range(len(tmpResp)):
-                    if tmpResp[i].find('ASSISTANT:') != -1:
-                        last_index = i
-                ai_response = tmpResp[last_index]
-                ai_response = ai_response.replace("ASSISTANT:", "")
-                ai_response = ai_response.replace("\n", "")
-                ai_response = ai_response.replace("\_", "_")
-
-                print(ai_response)
-                if ai_response == None:
-                    state.messages[-1][-1] = "ASSISTANT did not reply correctly; the raw reply was:\n" + all_text
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-                else:
-                    plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
-                    cfg.set_last_plugin_return(plugin_resp)
-                    print(plugin_resp)
-                    state.messages[-1][-1] = "Model inference info:\n" + ai_response + "\n\nDB-GPT execution result:\n" + plugin_resp
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-        except NotCommands as e:
-            print("Command execution: " + e.message)
-            state.messages[-1][-1] = "Command execution: " + e.message + "\nModel output:\n" + str(ai_response)
-            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-    else:
+        # Streaming output
         state.messages[-1][-1] = "▌"
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
@@ -413,6 +381,7 @@ def build_single_model_ui():
     """
 
     state = gr.State()
+
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
 
     with gr.Accordion("参数", open=False, visible=False) as parameter_row:
@@ -472,7 +441,7 @@ def build_single_model_ui():
             folder_files = gr.File(label="添加文件夹",
                                    accept_multiple_files=True,
                                    file_count="directory",
-                                   show_label=False)
+                                    show_label=False)
             load_folder_button = gr.Button("上传并加载到知识库")
 
     with gr.Blocks():
@@ -516,8 +485,8 @@ def build_single_model_ui():
         [state, chatbot] + btn_list
     )
     vs_add.click(fn=save_vs_name, show_progress=True,
-                 inputs=[vs_name],
-                 outputs=[vs_name])
+                     inputs=[vs_name],
+                     outputs=[vs_name])
     load_file_button.click(fn=knowledge_embedding_store,
                            show_progress=True,
                            inputs=[vs_name, files],
@@ -569,6 +538,7 @@ def save_vs_name(vs_name):
     vector_store_name["vs_name"] = vs_name
     return vs_name
 
+
 def knowledge_embedding_store(vs_id, files):
     # vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
@@ -584,10 +554,10 @@ def knowledge_embedding_store(vs_id, files):
                                                     "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
     knowledge_embedding_client.knowledge_embedding()
 
-    logger.info("knowledge embedding success")
     return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
 
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -603,7 +573,7 @@ if __name__ == "__main__":
 
     # Initialize configuration
     cfg = Config()
-    dbs = get_database_list()
+    dbs = cfg.local_db.get_database_list()
 
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

From ff997230148670cc41207810c9ffe2ead46d9f33 Mon Sep 17 00:00:00 2001
From: yhjun1026 <460342015@qq.com>
Date: Thu, 25 May 2023 10:04:25 +0800
Subject: [PATCH 02/22] =?UTF-8?q?=E5=A4=9A=E5=9C=BA=E6=99=AF=E5=AF=B9?=
 =?UTF-8?q?=E6=9E=B6=E6=9E=84=E4=B8=80=E6=9C=9F0525?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pilot/common/markdown_text.py     | 20 +++++++---
 pilot/out_parser/base.py          | 61 ++++++++++++++++++-----------
 pilot/prompts/prompt_new.py       | 12 +++++-
 pilot/scene/base_chat.py          |  6 +--
 pilot/scene/base_message.py       | 13 +++++++
 pilot/scene/chat_db/chat.py       | 48 ++++++++++++++----------
 pilot/scene/chat_db/example.py    |  0
 pilot/scene/chat_db/out_parser.py | 24 ++++++++++--
 pilot/scene/chat_db/prompt.py     | 23 +++++++++---
 pilot/scene/chat_normal/prompt.py | 31 ++++++++++++++++
 pilot/scene/message.py            | 14 ++++++-
 pilot/server/webserver.py         |  1 -
 12 files changed, 190 insertions(+), 63 deletions(-)
 create mode 100644 pilot/scene/chat_db/example.py
 create mode 100644 pilot/scene/chat_normal/prompt.py

diff --git a/pilot/common/markdown_text.py b/pilot/common/markdown_text.py
index b0f96e15c..1d90ba645 100644
--- a/pilot/common/markdown_text.py
+++ b/pilot/common/markdown_text.py
@@ -2,10 +2,15 @@ import markdown2
 import pandas as pd
 
 def datas_to_table_html(data):
-    df = pd.DataFrame(data)
-    table_style = """"""
-    html_table = df.to_html(index=False, header=False, border = True)
-    return table_style + html_table
+    df = pd.DataFrame(data[1:], columns=data[0])
+    table_style = """"""
+    html_table = df.to_html(index=False, escape=False)
+
+    html = f"{table_style}{html_table}"
+
+    return html.replace("\n", " ")
@@ -43,6 +48,9 @@ def generate_htm_table(data):
 
 
 if __name__ == "__main__":
-    mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
-
-    print(generate_htm_table(mk_text))
\ No newline at end of file
+    # mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
+    # print(generate_htm_table(mk_text))
+
+    table_style = """"""
+
+    print(table_style.replace("\n", " "))
\ No newline at end of file
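datas_to_table_html now treats the first row of data as the header instead of rendering a headerless table. A small illustration of that contract (editorial sketch; values are made up, and the CSS in table_style did not survive extraction above):

    import pandas as pd

    data = [("user_name", "city"), ("zhangsan", "上海"), ("test1", "成都")]
    df = pd.DataFrame(data[1:], columns=data[0])
    html = df.to_html(index=False, escape=False)
    # html now contains a <table> whose header row is user_name / city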
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 1c80d4a08..bbe8e5665 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -21,13 +21,23 @@ from pilot.prompts.base import PromptValue
 T = TypeVar("T")
 
 
-class BaseOutputParser(BaseModel, ABC, Generic[T]):
+class BaseOutputParser(ABC):
     """Class to parse the output of an LLM call.
 
     Output parsers help structure language model responses.
     """
 
-    def parse_model_nostream_resp(self, response, sep: str):
+    def __init__(self, sep: str, is_stream_out: bool):
+        self.sep = sep
+        self.is_stream_out = is_stream_out
+
+    # TODO: bind this to the model later
+    def _parse_model_stream_resp(self, response, sep: str):
+        pass
+
+    def _parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
         text = text.rstrip()
         respObj = json.loads(text)
@@ -52,35 +62,44 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
         else:
             raise ValueError("Model server error!code=" + respObj_ex['error_code']);
 
-    @abstractmethod
-    def parse(self, text: str) -> T:
-        """Parse the output of an LLM call.
-
-        A method which takes in a string (assumed output of language model )
-        and parses it into some structure.
-
+    def parse_model_server_out(self, response) -> str:
+        """
+        Parse the model server http response.
         Args:
-            text: output of language model
+            response: the raw HTTP response returned by the model server
 
         Returns:
-            structured output
+            the extracted model output text
         """
+        if not self.is_stream_out:
+            return self._parse_model_nostream_resp(response, self.sep)
+        else:
+            ### TODO
+            return self._parse_model_stream_resp(response, self.sep)
 
-    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
-        """Optional method to parse the output of an LLM call with a prompt.
-
-        The prompt is largely provided in the event the OutputParser wants
-        to retry or fix the output in some way, and needs information from
-        the prompt to do so.
+    def parse_prompt_response(self, model_out_text) -> T:
+        """
+        Parse model output text into the response structure the prompt defines.
         Args:
-            completion: output of language model
-            prompt: prompt value
+            model_out_text: raw model output text
 
         Returns:
-            structured output
+            the prompt-defined response object
         """
-        return self.parse(completion)
+        pass
+
+    def parse_view_response(self, ai_text) -> str:
+        """
+        Parse the AI response into content for the user view.
+        Args:
+            ai_text: the AI response (or tool result) to render
+
+        Returns:
+            renderable view content
+        """
+        pass
 
     def get_format_instructions(self) -> str:
         """Instructions on how the LLM output should be formatted."""
diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py
index 8dcdcdf51..82eaa543b 100644
--- a/pilot/prompts/prompt_new.py
+++ b/pilot/prompts/prompt_new.py
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
 
 from pilot.common.formatting import formatter
 from pilot.out_parser.base import BaseOutputParser
-
+from pilot.common.schema import SeparatorStyle
 
 def jinja2_formatter(template: str, **kwargs: Any) -> str:
     """Format a template using jinja2."""
@@ -36,7 +36,17 @@ class PromptTemplate(BaseModel, ABC):
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
     response_format: str
+    """Use streaming output by default."""
+    stream_out: bool = True
+    """The output parser for this scene."""
     output_parser: BaseOutputParser = None
+    """The separator used when concatenating messages."""
+    sep: str = SeparatorStyle.SINGLE.value
+
+    class Config:
+        """Configuration for this pydantic object."""
+        arbitrary_types_allowed = True
 
     @property
     def _prompt_type(self) -> str:
         """Return the prompt type key."""
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index c283df625..24a8d9a58 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -32,16 +32,14 @@ from pilot.configs.config import Config
 logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
 headers = {"User-Agent": "dbgpt Client"}
 CFG = Config()
 
 class BaseChat(ABC):
     chat_scene: str = None
-    memory: BaseChatHistoryMemory
     llm_model: Any = None
-    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
     temperature: float = 0.6
     max_new_tokens: int = 1024
     # By default, keep the last two rounds of conversation records as the context
     chat_retention_rounds: int = 2
-    sep = SeparatorStyle.SINGLE.value
+
     class Config:
         """Configuration for this pydantic object."""
         arbitrary_types_allowed = True
diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py
index 504b4e152..5cd8c4426 100644
--- a/pilot/scene/base_message.py
+++ b/pilot/scene/base_message.py
@@ -61,6 +61,17 @@ class AIMessage(BaseMessage):
         return "ai"
 
 
+class ViewMessage(BaseMessage):
+    """Type of message that carries rendered output for the user view."""
+
+    example: bool = False
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "view"
+
+
 class SystemMessage(BaseMessage):
     """Type of message that is a system message."""
 
@@ -132,6 +143,8 @@ def _message_from_dict(message: dict) -> BaseMessage:
         return AIMessage(**message["data"])
     elif _type == "system":
         return SystemMessage(**message["data"])
+    elif _type == "view":
+        return ViewMessage(**message["data"])
     else:
         raise ValueError(f"Got unexpected type: {_type}")
diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py
index ec11c49d7..0159bb2b3 100644
--- a/pilot/scene/chat_db/chat.py
+++ b/pilot/scene/chat_db/chat.py
@@ -13,7 +13,7 @@ from sqlalchemy import (
 )
 from typing import Any, Iterable, List, Optional
 
-from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage
+from pilot.scene.base_message import BaseMessage, SystemMessage, HumanMessage, AIMessage, ViewMessage
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -26,13 +26,15 @@ from pilot.utils import (
 )
 from pilot.common.markdown_text import generate_markdown_table, generate_htm_table, datas_to_table_html
 from pilot.scene.chat_db.prompt import chat_db_prompt
-
+from pilot.out_parser.base import BaseOutputParser
+from pilot.scene.chat_db.out_parser import DbChatOutputParser
 
 CFG = Config()
 
 
 class ChatWithDb(BaseChat):
     chat_scene: str = ChatScene.ChatWithDb.value
 
+
     """Number of results to return from the query"""
 
     def __init__(self, chat_session_id, db_name, user_input):
@@ -77,17 +79,22 @@ class ChatWithDb(BaseChat):
             "prompt": self.generate_llm_text(),
             "temperature": float(self.temperature),
             "max_new_tokens": int(self.max_new_tokens),
-            "stop": self.sep_style.value,
+            "stop": self.prompt_template.sep,
         }
         logger.info(f"Request: \n{payload}")
 
         try:
             ### Call the non-streaming model service interface
 
-            # TODO - TEST
+            # # TODO - TEST
             # response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, json=payload, timeout=120)
-            # clear_response = self.prompt_template.output_parser.parse_model_nostream_resp(response, self.sep_style)
-            # sql_action = self.prompt_template.output_parser.parse(clear_response)
+            # ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+            #
+            # prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+            # self.current_message.add_ai_message(json.dumps(prompt_define_response._asdict()))
+            # result = self.database.run(self.db_connect, prompt_define_response.SQL)
+
+
             resp_test = {
                 "SQL": "select * from users",
                 "thoughts": {
@@ -100,12 +107,10 @@ class ChatWithDb(BaseChat):
             }
 
             sql_action = SqlAction(**resp_test)
-
-            # self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
-
+            self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
             result = self.database.run(self.db_connect, sql_action.SQL)
-
-            self.current_message.add_ai_message(f"{datas_to_table_html(result)}")
+            self.current_message.add_view_message(self.prompt_template.output_parser.parse_view_response(result))
 
         except Exception as e:
             logger.error("model response parse failed!" + str(e))
@@ -118,17 +123,19 @@ class ChatWithDb(BaseChat):
         ret = [[None, None]]
         # A single round of conversation has exactly one User record and one AI record
         # TODO: show the reasoning process in the front end...
-        for message in enumerate(self.current_message.messages):
+        for message in self.current_message.messages:
             if (isinstance(message, HumanMessage)):
                 ret[-1][-2] = message.content
-            if (isinstance(message, AIMessage)):
+            # whether to show the reasoning process
+            if (isinstance(message, ViewMessage)):
                 ret[-1][-1] = message.content
+
         return ret
 
     # Temporary, for front-end compatibility
     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
-            if message.type == 'ai':
+            if message.type == 'view':
                 return message.content
         return None
 
@@ -137,28 +144,31 @@ class ChatWithDb(BaseChat):
         text = ""
         ### Process the history first
         if (len(self.history_message) > self.chat_retention_rounds):
-            ### Merge the first round of the history with the most recent rounds to build the context
+            ### Merge the first round of the history with the most recent rounds; messages meant only for the user view are filtered out of the context
             for first_message in self.history_message[0].messages:
-                text += first_message.type + ":" + first_message.content + self.sep
+                if not isinstance(first_message, ViewMessage):
+                    text += first_message.type + ":" + first_message.content + self.prompt_template.sep
 
             index = self.chat_retention_rounds - 1
             for last_conversation in self.history_message[-index:]:
                 for last_message in last_conversation.messages:
-                    text += last_message.type + ":" + last_message.content + self.sep
+                    if not isinstance(last_message, ViewMessage):
+                        text += last_message.type + ":" + last_message.content + self.prompt_template.sep
 
         else:
             ### Concatenate the full history directly
             for conversation in self.history_message:
                 for message in conversation.messages:
-                    text += message.type + ":" + message.content + self.sep
+                    if not isinstance(message, ViewMessage):
+                        text += message.type + ":" + message.content + self.prompt_template.sep
 
         ### current conversation
         for now_message in self.current_message.messages:
-            text += now_message.type + ":" + now_message.content + self.sep
+            text += now_message.type + ":" + now_message.content + self.prompt_template.sep
 
         return text
 
-    @property
-    def chat_type(self) -> str:
+    @classmethod
+    def chat_type(cls) -> str:
         return ChatScene.ChatWithDb.value
diff --git a/pilot/scene/chat_db/example.py b/pilot/scene/chat_db/example.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_db/out_parser.py b/pilot/scene/chat_db/out_parser.py
index fb1ae770c..8378a829f 100644
--- a/pilot/scene/chat_db/out_parser.py
+++ b/pilot/scene/chat_db/out_parser.py
@@ -11,8 +11,9 @@ from typing import (
     TypeVar,
     Union,
 )
+import pandas as pd
 
-from pilot.out_parser.base import BaseOutputParser
+from pilot.out_parser.base import BaseOutputParser, T
 
 
 class SqlAction(NamedTuple):
@@ -22,8 +23,15 @@ class SqlAction(NamedTuple):
 
 class DbChatOutputParser(BaseOutputParser):
 
-    def parse(self, text: str) -> SqlAction:
-        cleaned_output = text.rstrip()
+    def __init__(self, sep: str, is_stream_out: bool):
+        super().__init__(sep=sep, is_stream_out=is_stream_out)
+
+    def parse_model_server_out(self, response) -> str:
+        return super().parse_model_server_out(response)
+
+    def parse_prompt_response(self, model_out_text):
+        cleaned_output = model_out_text.rstrip()
         if "```json" in cleaned_output:
             _, cleaned_output = cleaned_output.split("```json")
         if "```" in cleaned_output:
@@ -40,6 +48,16 @@ class DbChatOutputParser(BaseOutputParser):
 
         return SqlAction(sql, thoughts)
 
+    def parse_view_response(self, data) -> str:
+        ### render tool output data as an HTML table view
+        df = pd.DataFrame(data[1:], columns=data[0])
+        table_style = """"""
+        html_table = df.to_html(index=False, escape=False)
+        html = f"{table_style}{html_table}"
+        return html.replace("\n", " ")
+
     @property
     def _type(self) -> str:
         return "sql_chat"
diff --git a/pilot/scene/chat_db/prompt.py b/pilot/scene/chat_db/prompt.py
index 5aefefe06..2f6a83c52 100644
--- a/pilot/scene/chat_db/prompt.py
+++ b/pilot/scene/chat_db/prompt.py
@@ -2,7 +2,8 @@ import json
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
-from pilot.scene.chat_db.out_parser import DbChatOutputParser
+from pilot.scene.chat_db.out_parser import DbChatOutputParser, SqlAction
+from pilot.common.schema import SeparatorStyle
 
 CFG = Config()
 
@@ -21,6 +22,15 @@ You can order the results by a relevant column to return the most interesting ex
 Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
 Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
 
+"""
+
+_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
+Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
+Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
+Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
+Pay attention to use CURDATE() function to get the current date, if the question involves "today".
+
 """
 
 PROMPT_RESPONSE = """You should only respond in JSON, in the following format:
@@ -37,17 +47,18 @@ RESPONSE_FORMAT = {
     "SQL": "SQL Query to run"
 }
 
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_STREAM_OUT = False
+
 chat_db_prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDb.value,
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
     response_format=json.dumps(RESPONSE_FORMAT, indent=4),
-    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
-    output_parser=DbChatOutputParser()
+    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE + PROMPT_SUFFIX,
+    output_parser=DbChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
 )
 
 CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})
 
-
-if __name__ == "__main__":
-    resp = chat_db_prompt.format(input="查询用户信息", table_info="user(a,b,c,d)", dialect="mysql", top_k=10)
-    print(resp)
diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py
new file mode 100644
index 000000000..fd21f2102
--- /dev/null
+++ b/pilot/scene/chat_normal/prompt.py
@@ -0,0 +1,31 @@
+import builtins
+
+
+def stream_write_and_read(lst):
+    # flatten the iterable into the generator's output with yield from
+    yield from lst
+    while True:
+        val = yield
+        lst.append(val)
+
+
+if __name__ == "__main__":
+    # create an empty list
+    my_list = []
+
+    # write data through the generator
+    stream_writer = stream_write_and_read(my_list)
+    next(stream_writer)
+    stream_writer.send(10)
+    print(1)
+    stream_writer.send(20)
+    print(2)
+    stream_writer.send(30)
+    print(3)
+
+    # read data through the generator
+    stream_reader = stream_write_and_read(my_list)
+    next(stream_reader)
+    print(stream_reader.send(None))
+    print(stream_reader.send(None))
+    print(stream_reader.send(None))
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
index 3a11f294f..8dc3eaa3e 100644
--- a/pilot/scene/message.py
+++ b/pilot/scene/message.py
@@ -9,8 +9,7 @@ from typing import (
     List,
 )
 
-from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, messages_to_dict, \
-    messages_from_dict
+from pilot.scene.base_message import BaseMessage, AIMessage, HumanMessage, SystemMessage, ViewMessage, messages_to_dict, messages_from_dict
 
 
 class OnceConversation:
@@ -27,12 +26,23 @@ class OnceConversation:
 
     def add_user_message(self, message: str) -> None:
         """Add a user message to the store"""
+        has_message = any(isinstance(instance, HumanMessage) for instance in self.messages)
+        if has_message:
+            raise ValueError("Already have a Human message")
         self.messages.append(HumanMessage(content=message))
 
     def add_ai_message(self, message: str) -> None:
         """Add an AI message to the store"""
+        has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
+        if has_message:
+            raise ValueError("Already have an AI message")
         self.messages.append(AIMessage(content=message))
         """ """
 
+    def add_view_message(self, message: str) -> None:
+        """Add a view message to the store"""
+        self.messages.append(ViewMessage(content=message))
+        """ """
+
     def add_system_message(self, message: str) -> None:
         """Add a system message to the store"""
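The new guards make OnceConversation reject a second Human or AI message within one round, which matches the one-user-record / one-AI-record assumption in chat_show. For example (editorial sketch, not part of the patch):

    once = OnceConversation()
    once.add_user_message("first question")
    try:
        once.add_user_message("second question")
    except ValueError:
        pass  # raised with "Already have a Human message"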
f"{chat.current_ai_response()}" yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 From dce4291d2da70d83814721d64b20f4ee8a1facf8 Mon Sep 17 00:00:00 2001 From: csunnyS>~ ztUzP-d}+;v>+P)!&1LwD+x-rSY$(v0q!n%)YJ{zih{ftvo&5B-Z*1~*sw4>S6i38nl< z%% f^ z@T^5#q7f--ln1?F@~&m~XMrpit9&A)a+jb2UhCXz%f<$)) 8*UmAb>7ivyg68b56LT P4aHn@LqBC z-sAI{2~az2R6E!3MQj}05c7-7s0gXaBJrzo{s80ft6=l{iL3bmN&H^B`Jp%Yy>OFx znDc+S@qbSez&EcNITPUiL(}(FptYK&Op~S*L=!G~8z9I3Q$h3R7tKJec{P4W5YK#& zHA%2Ne+LO9n4u|{3gS ^q!MqKoBMFI|4^H9_WoiQd;1A9v2`khLD{WHJ*9_9Q z4Xh;zswW9=B#G$MjL?Hb=*&l$@JGN%B4s2Ep6VmvSCLx~gAKQ+o%zTklIT;-=mWRt z6G$|CJ{n0N8t5MLqB$Cq^cT*8fmAPhyt)O|odwOpFS_JdriIvFL~#rPacu5!oLX_B z$#LQfagxdLQVa1yqzRt{66D+y0*IodXSfItKx6 E|N2E7BZ1qGEvF0(C;&G zw6id^v&R;;xFlI2Dk6g_ISWoX2?05ri#dg}-^pBZsTy+8T`cLYbNK}G x8a z=BnK1sgdRBkQHbw=F12c7-$z*rxe(07uu#2I<*v{yA*nB7x~>6N~IJ9FCr?w{DAwS zRPAE(#p0xvqI|OAV!@K^`{M8SCG~=($t|UYWToZzrKN&pFvM^2$Q=mq62r3U?k~T) zE>|F~c!!&;wrTZ;4C#Ea;v%IIezRDadRTcyR)wllh1pt({!oQ2R7H?lO}tc%_fY+c zyoO1qhOM=l<)MZ{sOEiYE#Fe@yN6mH@;WJUe1Wmy@P2o!w3$+PxItfXs3CR!f={n7+IxU%=ExD;J zl}jzP RkQt$N@8tHmrS`4X_6^UDnbwXA zosN^#&Z~#^r__!c@~(58F3iVHY~d~(-EM-kZgj8i*KOTf58VvHJ*TNXG!(s@y1g7} zy$|Hw@7sD&y?Q@*^+-PU!liWk6x#Y^+IqAo`t&G%E9w5$5&muD)o;+&e<6g_u*=`i zEqGO%`0n1)Qz*At)|$lwYLHTpq&t|JHki3Qm`gEKs5?}eHdMJhR7(MC)P=RC!8(^= zy%fU(y2G%v;j!i6DT >9XQNFmf2Mp&c9S+E z0xK2;mD>t6*pp9u7n&eXpFmg`>GUR;RwfwQCs|%kzFV1m|1>F-K7}wfO6pBXtxO?I zIpV`+VB{v6%8-bn5c j$BfVnlmP!9G$sJ83%y%R4%h zy9WBZCK m1i6^t{&A>9@pw0w~8DO`W$yYAHyh5Mlwz;$!FeeX@#_Y)Lm&* z19@SwhGz`M!@1GD q&m;lJ4g3D$LTyL$ULW8J7<6!r8X?W zmOMA_ncq>-Q_ec>FD}DhZ+wDpko#O>K3~#^Tq$&3(q&vStX_TTywVW8R`b2Kr@FQ@ zxcLsh@-(=1%DnM~ce)D8uRsT)DQwvUbAAdV;xuvtDDQps?{hQn3)k*TsU9i~9%?fm z8rL3LI|omM-G@^zc-9^VJ1ge&pJp itn95$f{DTn5&)cQO`PCuHE|e zDdiQrMzwx#%4^RC_QiUu_OEr2FUEZVxTv!Tlj-_OyT`-s6Zv ee zi)~Ift0zMIXB!#esn52$4QJa;p;QLX%gv4lBgsb%QNn^Q`%{f7_AwfEx?yoJDcI>nVE*bTB7YG}mycx=xgVyzey~JWOoV}&VIFgI309QhA7rG_> z;SKV3l*mic?Pv)ineAV)WP_%$3bg1uv5JbMJ8@stWp)y@wFkc?>KUW&ChA9#?j{*0 z%Iqdvr3{*-*yf<`rPz&0EQlN=Z$F;8AIPzpQmd z!(rG_MaMM;yK5W91HxkZRfD}|AjhM$dX)8ny|!0E?!=6g8!VQI`>pRpjKXJ$2Hvz8 zB=@IjH!=KA^I-w(Ps^VM%+uDt{c@*G)IQ9_ADI>g?K7S>(a+LBaG|PB6jXVK9#rCp zvmOjStarWmqQcd^L{jqK`-wHMb~5fhUnX{BdkAtB(P!E?4rNy8a>1Be%N$``ox)DT z+-o*Yqi^lW977z-RXNHgFkU;?eo|tyt&!oByPT9HKe?Qc (jyHH&II#F!?AOS1`slvTAAoR|3Z^RFALLh7c)*W}JEJ^~ z%Qyu~#+;H I|ApAy!nM4vB@?0uh~PvM= SHoAaTM7zKLY2h7j)9ySo&4|?+&y+$*Kt-eL3}@p*Q;bIYi X?>Avzd{7DT_Jo7Gv8~=l+H7aR$>(lA^9)r- zIv@;xTharW(|-a)bF-=J-6URbIs}28X=q-MLM2%pV 9FshWP*;dF~=kIH8Z`Q zEZ^hgkLj624+fR}9piJ@^MmjE2bR3QM;Phmv+1DAy&Z~8sP+2BVIl+5U92&Ad8fIz zNn0a~K@i 5wz8AZe(UrB^MUq}Ic%ysB}#${UJha(wkrbWAtXHyJsT+zpUF)kC4ugt}fJjVUV zoU=`jcXs>ZhRtQUa|plfe&o>pVqxr@OLk&Gc5~a|6g8DXx5FqX181F#FCfTg d=PC~o)v9w!rqMR3 zRw<=Z=$6K>nvKC*75KW$2s2z$>2Rv(OW>hu{^2}S{i-@Xf!&nJdSU2gRRq3(iI%9` z?0D3VIx_|i3k5cW)7piZn&XwiNy^;(AGHR~{1fXt51pl!9}R8gw)Rd-ZK-B8I-wkY zjH0cVw`pCQhiqwC3yT+EvV1MmBV|9{g|8g^V<-O!&s;ol20oqPQoqT;X*DLNe;MW4 zzMD^~?HjbTa_`d7m0#&~6u$Ow=vr~TWAF3SYS0&G{^rEA!Y|cx9huIp8zXAZ_-R(@ zk8D%R%F3IcOg0->8g9KLip#DWprY# M2#UH-YI5ht_S^bK8eVPEsECD3tf0@l7L zrr>xWCASYHdrrT$04&^&JMp#mj`aKciTK3j;1AYjoYruU;nY#?LeY~0<;>;)5v+q~ z*Rnmz!u#ox#P=n7D~H0pDKmBI5e^H tm3^YdR-TuCqbhZ_sRl;M;j+tKTb?_t_5v^iyUM+nt?S%OO(QryP60=4yZn4v zlGWP33_quB;&Td4M2lSJDS0hg65T`JEB=l2)IEeN?eR}n+Fg|?ec1M4sA|)D=>iG= zaBNED>2>`Y@6(Ofp1_{K(%hfxrbh1tbe!sm1fQGyqYtNEQunJD` ZL;F#E=jbn?fkv(<2Kz|p~c?bzNpguUJ~W~ zbY{oZ3(cF+cKxg6wT>GToQE^l&P_$`t5e(P=jBX;ryzUyruV7OdgYq7c>mRQQsv|0 zmM{E%ZF^UZ)ElVjJt+kSxp|{#g0aq>XnWw_HgCLNqWNs0`VlbuedO~L>-D@q^*t^0 