update: merge dev branch

aries-ckt 2023-05-25 21:51:52 +08:00
commit b42e8e4549
47 changed files with 2218 additions and 63 deletions

@@ -18,7 +18,7 @@
 #** LLM MODELS **#
 #*******************************************************************#
 LLM_MODEL=vicuna-13b
-MODEL_SERVER=http://your_model_server_url
+MODEL_SERVER=http://120.79.27.110:8000
 LIMIT_MODEL_CONCURRENCY=5
 MAX_POSITION_EMBEDDINGS=4096

.gitignore vendored (2 lines changed)

@@ -6,6 +6,8 @@ __pycache__/
 # C extensions
 *.so
+message/
 .env
 .idea
 .vscode

@@ -44,6 +44,11 @@ Run on an RTX 4090 GPU. [YouTube](https://www.youtube.com/watch?v=1PWI6F89LPo)
 <img src="./assets/demo_en.gif" width="600px" />
 </p>
 
+### Run Plugin
+<p align="center">
+<img src="./assets/auto_sql_en.gif" width="600px" />
+</p>
+
 ### SQL Generation
 
 1. Generate Create Table SQL
@@ -185,7 +190,7 @@ We provide a user interface for Gradio, which allows you to use DB-GPT through o
 To use multiple models, modify the LLM_MODEL parameter in the .env configuration file to switch between the models.
 
-####Create your own knowledge repository:
+### Create your own knowledge repository:
 
 1.Place personal knowledge files or folders in the pilot/datasets directory.
@@ -213,7 +218,7 @@ Run the Python interpreter and type the commands:
 ## Acknowledgement
 
-The achievements of this project are thanks to the technical community, especially the following projects:
+This project is standing on the shoulders of giants and is not going to work without the open-source communities. Special thanks to the following projects for their excellent contribution to the AI industry:
 
 - [FastChat](https://github.com/lm-sys/FastChat) for providing chat services
 - [vicuna-13b](https://lmsys.org/blog/2023-03-30-vicuna/) as the base model
 - [langchain](https://langchain.readthedocs.io/) tool chain
@@ -245,4 +250,4 @@ This project follows the git-contributor [spec](https://github.com/xudafeng/git-
 The MIT License (MIT)
 
 ## Contact Information
-We are working on building a community, if you have any ideas about building the community, feel free to contact us. [Discord](https://discord.com/invite/twmZk3vv)
+We are working on building a community, if you have any ideas about building the community, feel free to contact us. [Discord](https://discord.gg/kMFf77FH)

@@ -39,8 +39,9 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
 <img src="./assets/演示.gif" width="600px" />
 </p>
 
+### SQL 插件化执行
 <p align="center">
-<img src="./assets/Auto-DB-GPT.gif" width="600px" />
+<img src="./assets/auto_sql.gif" width="600px" />
 </p>
 
 ### SQL 生成

assets/auto_sql.gif (new binary file, 3.5 MiB, content not shown)

assets/auto_sql_en.gif (new binary file, 3.1 MiB, content not shown)

@@ -7,16 +7,14 @@
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
-import toml
+import os
+import sys
 
 project = "DB-GPT"
 copyright = "2023, csunny"
 author = "csunny"
 
-with open("../pyproject.toml") as f:
-    data = toml.load(f)
-    version = data["tool"]["poetry"]["version"]
-release = version
+version = "0.1.0"
 
 html_title = project + " " + version
 
 # -- General configuration ---------------------------------------------------

@@ -89,7 +89,7 @@ Use Cases
 | Best Practices and built-in implementations for common DB-GPT use cases:
 
-- `Sql generation and diagnosis <./use_cases/sql_generation_and_diagnosis.html>`: SQL generation and diagnosis.
+- `Sql generation and diagnosis <./use_cases/sql_generation_and_diagnosis.html>`_: SQL generation and diagnosis.
 - `knownledge Based QA <./use_cases/knownledge_based_qa.html>`_: A important scene for user to chat with database documents, codes, bugs and schemas.

pilot/common/formatting.py (new file, 38 lines)

@@ -0,0 +1,38 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
formatter = StrictFormatter()
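A quick usage sketch of the StrictFormatter above; the inputs are illustrative only:

```python
# Illustrative only: exercising the module-level `formatter` defined above.
print(formatter.format("Hello {name}", name="DB-GPT"))  # -> Hello DB-GPT

try:
    # Unused keyword arguments are rejected by check_unused_args.
    formatter.format("Hello {name}", name="DB-GPT", extra="unused")
except KeyError as e:
    print("rejected extra key:", e)

try:
    # Positional arguments are rejected by vformat.
    formatter.vformat("Hello {0}", ("DB-GPT",), {})
except ValueError as e:
    print("rejected positional args:", e)
```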

pilot/common/markdown_text.py (new file, 57 lines)

@@ -0,0 +1,57 @@
import markdown2
import pandas as pd
def datas_to_table_html(data):
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
return html.replace("\n", " ")
def generate_markdown_table(data):
    """Generate a Markdown table.

    data: a two-dimensional list containing the header row and the table rows.
    """
    # Number of columns
    num_cols = len(data[0])
    # Build the header row
    header = "| "
    for i in range(num_cols):
        header += data[0][i] + " | "
    # Build the separator row
    separator = "| "
    for i in range(num_cols):
        separator += "--- | "
    # Build the table body
    content = ""
    for row in data[1:]:
        content += "| "
        for i in range(num_cols):
            content += str(row[i]) + " | "
        content += "\n"
    # Join header, separator, and body
    table = header + "\n" + separator + "\n" + content
    return table
def generate_htm_table(data):
markdown_text = generate_markdown_table(data)
html_table = markdown2.markdown(markdown_text, extras=["tables"])
return html_table
if __name__ == "__main__":
# mk_text = "| user_name | phone | email | city | create_time | last_login_time | \n| --- | --- | --- | --- | --- | --- | \n| zhangsan | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| hanmeimei | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| wangwu | 123 | None | 上海 | 2023-05-13 09:09:09 | None | \n| test1 | 123 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test2 | 123 | None | 成都 | 2023-05-11 09:09:09 | None | \n| test3 | 23 | None | 成都 | 2023-05-12 09:09:09 | None | \n| test4 | 23 | None | 成都 | 2023-05-09 09:09:09 | None | \n| test5 | 123 | None | 上海 | 2023-05-08 09:09:09 | None | \n| test6 | 123 | None | 成都 | 2023-05-08 09:09:09 | None | \n| test7 | 23 | None | 上海 | 2023-05-10 09:09:09 | None |\n"
# print(generate_htm_table(mk_text))
table_style = """<style>\n table {\n border-collapse: collapse;\n width: 100%;\n }\n th, td {\n border: 1px solid #ddd;\n padding: 8px;\n text-align: center;\n line-height: 150px; \n }\n th {\n background-color: #f2f2f2;\n color: #333;\n font-weight: bold;\n }\n tr:nth-child(even) {\n background-color: #f9f9f9;\n }\n tr:hover {\n background-color: #f2f2f2;\n }\n </style>"""
print(table_style.replace("\n", " "))
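For reference, a small illustrative input for the helpers above; the sample rows are made up:

```python
# First row is the header, remaining rows are data.
data = [
    ["user_name", "city"],
    ["zhangsan", "Shanghai"],
    ["hanmeimei", "Shanghai"],
]
print(generate_markdown_table(data))
# | user_name | city |   (roughly; each cell is pipe-delimited)
# | --- | --- |
# | zhangsan | Shanghai | ...
print(generate_htm_table(data))   # the same table rendered to HTML via markdown2
print(datas_to_table_html(data))  # pandas-rendered HTML with the inline styles above
```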

pilot/common/schema.py (new file, 9 lines)

@@ -0,0 +1,9 @@
from enum import auto, Enum
from typing import List, Any
class SeparatorStyle(Enum):
SINGLE = "###"
TWO = "</s>"
THREE = auto()
FOUR = auto()
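These separator values are consumed when prompts are assembled later in this diff (PromptTemplate.sep defaults to SeparatorStyle.SINGLE.value, and BaseChat.generate_llm_text appends the separator after every turn). A minimal illustration:

```python
from pilot.common.schema import SeparatorStyle

sep = SeparatorStyle.SINGLE.value  # "###"
turns = ["system: you are a SQL assistant", "human: show users", "assistant:"]
# Roughly how generate_llm_text stitches turns together:
print("".join(turn + sep for turn in turns))
# system: you are a SQL assistant###human: show users###assistant:###
```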

pilot/common/sql_database.py (new file, 316 lines)

@@ -0,0 +1,316 @@
from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
import sqlalchemy
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import CursorResult, Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
f' Columns: {str(index["column_names"])}'
)
class Database:
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")
self._inspector = inspect(self._engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
self._db_sessions = Session
        self._all_tables = set()
        self.view_support = view_support
        self._usable_tables = set()
        # Assumed fix: the original initialized all of the fields below to set(),
        # with several duplicated assignments; wire them to the constructor
        # arguments instead, matching the commented-out validation logic below.
        self._include_tables = set(include_tables) if include_tables else set()
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        self._custom_table_info = custom_table_info
        self._indexes_in_table_info = indexes_in_table_info
        self._sample_rows_in_table_info = sample_rows_in_table_info
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
# self._all_tables = set(
# self._inspector.get_table_names(schema=schema)
# + (self._inspector.get_view_names(schema=schema) if view_support else [])
# )
# self._include_tables = set(include_tables) if include_tables else set()
# if self._include_tables:
# missing_tables = self._include_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"include_tables {missing_tables} not found in database"
# )
# self._ignore_tables = set(ignore_tables) if ignore_tables else set()
# if self._ignore_tables:
# missing_tables = self._ignore_tables - self._all_tables
# if missing_tables:
# raise ValueError(
# f"ignore_tables {missing_tables} not found in database"
# )
# usable_tables = self.get_usable_table_names()
# self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
# if not isinstance(sample_rows_in_table_info, int):
# raise TypeError("sample_rows_in_table_info must be an integer")
#
# self._sample_rows_in_table_info = sample_rows_in_table_info
# self._indexes_in_table_info = indexes_in_table_info
#
# self._custom_table_info = custom_table_info
# if self._custom_table_info:
# if not isinstance(self._custom_table_info, dict):
# raise TypeError(
# "table_info must be a dictionary with table names as keys and the "
# "desired table info as values"
# )
# # only keep the tables that are also present in the database
# intersection = set(self._custom_table_info).intersection(self._all_tables)
# self._custom_table_info = dict(
# (table, self._custom_table_info[table])
# for table in self._custom_table_info
# if table in intersection
# )
# self._metadata = metadata or MetaData()
# # # including view support if view_support = true
# self._metadata.reflect(
# views=view_support,
# bind=self._engine,
# only=list(self._usable_tables),
# schema=self._schema,
# )
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> Database:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
return cls(create_engine(database_uri, **_engine_args), **kwargs)
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return self._all_tables - self._ignore_tables
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
warnings.warn(
"This method is deprecated - please use `get_usable_table_names`."
)
return self.get_usable_table_names()
def get_session(self, db_name: str):
session = self._db_sessions()
self._metadata = MetaData()
# sql = f"use {db_name}"
sql = text(f"use `{db_name}`")
session.execute(sql)
        # Reflect the table metadata for this schema
self._metadata.reflect(bind=self._engine, schema=db_name)
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=db_name)
+ (
self._inspector.get_view_names(schema=db_name)
if self.view_support
else []
)
)
return session
def get_current_db_name(self, session) -> str:
return session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self, session):
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
"""
cursor = session.execute(text(_sql))
results = cursor.fetchall()
return results
@property
def table_info(self) -> str:
"""Information about all tables in the database."""
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated in the paper.
"""
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
meta_tables = [
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
for table in meta_tables:
if self._custom_table_info and table.name in self._custom_table_info:
tables.append(self._custom_table_info[table.name])
continue
# add create table command
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
if self._indexes_in_table_info:
table_info += f"\n{self._get_table_indexes(table)}\n"
if self._sample_rows_in_table_info:
table_info += f"\n{self._get_sample_rows(table)}\n"
if has_extra_info:
table_info += "*/"
tables.append(table_info)
final_str = "\n\n".join(tables)
return final_str
def _get_sample_rows(self, table: Table) -> str:
# build the select command
command = select(table).limit(self._sample_rows_in_table_info)
# save the columns in string format
columns_str = "\t".join([col.name for col in table.columns])
try:
# get the sample rows
with self._engine.connect() as connection:
sample_rows_result: CursorResult = connection.execute(command)
# shorten values in the sample rows
sample_rows = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
)
# save the sample rows in string format
sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
# in some dialects when there are no rows in the table a
# 'ProgrammingError' is returned
except ProgrammingError:
sample_rows_str = ""
return (
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
f"{columns_str}\n"
f"{sample_rows_str}"
)
def _get_table_indexes(self, table: Table) -> str:
indexes = self._inspector.get_indexes(table.name)
indexes_formatted = "\n".join(map(_format_index, indexes))
return f"Table Indexes:\n{indexes_formatted}"
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
try:
return self.get_table_info(table_names)
except ValueError as e:
"""Format the error message"""
return f"Error: {e}"
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
cursor = session.execute(text(command))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = tuple(i[0:] for i in cursor.keys())
result = list(result)
result.insert(0, field_names)
return result
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
If the statement throws an error, the error message is returned.
"""
try:
return self.run(session, command, fetch)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
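A hedged usage sketch of this wrapper; the connection URI, schema, and table names are placeholders (the credentials mirror the LOCAL_DB_* defaults added to Config later in this diff):

```python
# Placeholder URI; pool settings copied from the Config change below.
db = Database.from_uri(
    "mysql+pymysql://root:aa123456@127.0.0.1:3306",
    engine_args={"pool_size": 10, "pool_recycle": 3600},
)
session = db.get_session("gpt-user")      # switch schema and reflect its tables
print(db.get_current_db_name(session))    # -> gpt-user
print(db.table_simple_info(session))      # [('users(id,user_name,...)',), ...]
rows = db.run(session, "select * from users")
# rows[0] is the header tuple, the rest are data rows
```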

pilot/configs/config.py

@@ -8,6 +8,7 @@ import nltk
 from auto_gpt_plugin_template import AutoGPTPluginTemplate
 
 from pilot.singleton import Singleton
+from pilot.common.sql_database import Database
 
 class Config(metaclass=Singleton):
@@ -39,6 +40,13 @@ class Config(metaclass=Singleton):
         self.use_mac_os_tts = False
         self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS")
 
+        # milvus or zilliz cloud configuration
+        self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
+        self.milvus_username = os.getenv("MILVUS_USERNAME")
+        self.milvus_password = os.getenv("MILVUS_PASSWORD")
+        self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
+        self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
+
         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
         self.image_provider = os.getenv("IMAGE_PROVIDER", True)
@@ -55,6 +63,7 @@ class Config(metaclass=Singleton):
         )
 
         self.speak_mode = False
+        self.prompt_templates = {}
 
         ### Related configuration of built-in commands
         self.command_registry = []
@@ -67,6 +76,8 @@ class Config(metaclass=Singleton):
         self.execute_local_commands = (
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
+        ### message store file
+        self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
 
         ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
         self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
@@ -91,6 +102,19 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
         self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
 
+        ### TODO Adapt to multiple types of databases
+        self.local_db = Database.from_uri(
+            "mysql+pymysql://"
+            + self.LOCAL_DB_USER
+            + ":"
+            + self.LOCAL_DB_PASSWORD
+            + "@"
+            + self.LOCAL_DB_HOST
+            + ":"
+            + str(self.LOCAL_DB_PORT),
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )
+
         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
         self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))

@@ -32,6 +32,18 @@ class MySQLOperator:
             results = cursor.fetchall()
             return results
 
+    def run_sql(self, db_name: str, sql: str, fetch: str = "all"):
+        with self.conn.cursor() as cursor:
+            cursor.execute("USE " + db_name)
+            cursor.execute(sql)
+            if fetch == "all":
+                result = cursor.fetchall()
+            elif fetch == "one":
+                result = cursor.fetchone()[0]  # type: ignore
+            else:
+                raise ValueError("Fetch parameter must be either 'one' or 'all'")
+            return str(result)
+
     def get_index(self, schema_name):
         pass

@@ -2,8 +2,9 @@
 # -*- coding:utf-8 -*-
 
 import dataclasses
-from enum import Enum, auto
-from typing import Any, List
+import uuid
+from enum import auto, Enum
+from typing import List, Any
 
 from pilot.configs.config import Config
@@ -42,6 +43,7 @@ class Conversation:
     # Used for gradio server
     skip_next: bool = False
    conv_id: Any = None
+    last_user_input: Any = None
 
     def get_prompt(self):
         if self.sep_style == SeparatorStyle.SINGLE:
@@ -269,6 +271,7 @@ conversation_types = {
     "native": "LLM原生对话",
     "default_knownledge": "默认知识库对话",
     "custome": "新增知识库对话",
+    "auto_execute_plugin": "对话使用插件",
 }
 
 conv_templates = {

@@ -0,0 +1,71 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, TypeVar

from pydantic import BaseModel, Extra, Field, root_validator

# Assumed completion: PromptValue and the typing names above are used in this
# file but were not imported in the original diff.
from pilot.prompts.base import PromptValue

T = TypeVar("T")

class BaseOutputParser(BaseModel, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
"""
@abstractmethod
def parse(self, text: str) -> T:
"""Parse the output of an LLM call.
A method which takes in a string (assumed output of language model )
and parses it into some structure.
Args:
text: output of language model
Returns:
structured output
"""
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Optional method to parse the output of an LLM call with a prompt.
The prompt is largely provided in the event the OutputParser wants
to retry or fix the output in some way, and needs information from
the prompt to do so.
Args:
completion: output of language model
prompt: prompt value
Returns:
structured output
"""
return self.parse(completion)
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
class OutputParserException(Exception):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
"""
pass

pilot/memory/__init__.py (new empty file)

pilot/memory/chat_history/base.py (new file, 34 lines)

@@ -0,0 +1,34 @@
from __future__ import annotations
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from pilot.scene.message import OnceConversation
class BaseChatHistoryMemory(ABC):
def __init__(self):
self.conversations: List[OnceConversation] = []
@abstractmethod
def messages(self) -> List[OnceConversation]: # type: ignore
"""Retrieve the messages from the local file"""
@abstractmethod
def append(self, message: OnceConversation) -> None:
"""Append the message to the record in the local file"""
@abstractmethod
def clear(self) -> None:
"""Clear session memory from the local file"""

pilot/memory/chat_history/file_history.py (new file, 49 lines)

@@ -0,0 +1,49 @@
from typing import List
import json
import os
import datetime
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pathlib import Path
from pilot.configs.config import Config
from pilot.scene.message import (
OnceConversation,
conversation_from_dict,
conversations_to_dict,
)
CFG = Config()
class FileHistoryMemory(BaseChatHistoryMemory):
def __init__(self, chat_session_id: str):
now = datetime.datetime.now()
date_string = now.strftime("%Y%m%d")
path: str = f"{CFG.message_dir}/{date_string}"
os.makedirs(path, exist_ok=True)
dir_path = Path(path)
self.file_path = Path(dir_path / f"{chat_session_id}.json")
if not self.file_path.exists():
self.file_path.touch()
self.file_path.write_text(json.dumps([]))
def messages(self) -> List[OnceConversation]:
items = json.loads(self.file_path.read_text())
history: List[OnceConversation] = []
for onece in items:
messages = conversation_from_dict(onece)
history.append(messages)
return history
def append(self, once_message: OnceConversation) -> None:
historys = self.messages()
historys.append(once_message)
self.file_path.write_text(
json.dumps(conversations_to_dict(historys), ensure_ascii=False, indent=4),
encoding="UTF-8",
)
def clear(self) -> None:
self.file_path.write_text(json.dumps([]))
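An illustrative round trip through FileHistoryMemory; the session id is a placeholder, and OnceConversation's API is used as seen in pilot/scene/chat_db/chat.py later in this diff:

```python
memory = FileHistoryMemory("test-session-123")

once = OnceConversation()
once.add_user_message("show all users")  # API as used elsewhere in this diff
once.chat_order = 1
memory.append(once)

print(len(memory.messages()))  # -> 1, read back from the per-day JSON file
memory.clear()                 # resets the file to an empty JSON list
```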

pilot/out_parser/base.py (new file, 117 lines)

@@ -0,0 +1,117 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.prompts.base import PromptValue
T = TypeVar("T")
class BaseOutputParser(ABC):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
"""
def __init__(self, sep: str, is_stream_out: bool):
self.sep = sep
self.is_stream_out = is_stream_out
    # TODO: bind this to the model implementation later
    def _parse_model_stream_resp(self, response, sep: str):
        pass

    def _parse_model_nostream_resp(self, response, sep: str):
        text = response.text.strip()
        text = text.rstrip()
        text = text.lower()
        respObj = json.loads(text)
        xx = respObj["response"]
        xx = xx.strip(b"\x00".decode())
        respObj_ex = json.loads(xx)
        if respObj_ex["error_code"] == 0:
            all_text = respObj_ex["text"]
            ### Parse the returned text and extract the AI reply portion
            tmpResp = all_text.split(sep)
            last_index = -1
            for i in range(len(tmpResp)):
                if tmpResp[i].find("assistant:") != -1:
                    last_index = i
            ai_response = tmpResp[last_index]
            ai_response = ai_response.replace("assistant:", "")
            ai_response = ai_response.replace("\n", "")
            ai_response = ai_response.replace("\_", "_")
            ai_response = ai_response.replace("\*", "*")
            print("un_stream clear response:{}", ai_response)
            return ai_response
        else:
            raise ValueError("Model server error! code=" + str(respObj_ex["error_code"]))
def parse_model_server_out(self, response) -> str:
"""
parse the model server http response
Args:
response:
Returns:
"""
if not self.is_stream_out:
return self._parse_model_nostream_resp(response, self.sep)
else:
return self._parse_model_stream_resp(response, self.sep)
def parse_prompt_response(self, model_out_text) -> T:
"""
parse model out text to prompt define response
Args:
model_out_text:
Returns:
"""
pass
def parse_view_response(self, ai_text) -> str:
"""
parse the ai response info to user view
Args:
text:
Returns:
"""
pass
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
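For orientation, a sketch of the double-encoded payload that _parse_model_nostream_resp expects; the stand-in response object and values are illustrative only:

```python
import json

class FakeResponse:
    """Stand-in for requests.Response: only the .text attribute is read."""
    def __init__(self, text: str) -> None:
        self.text = text

inner = json.dumps({
    "error_code": 0,
    "text": "system: ...###human: show users###assistant: SELECT * FROM users",
})
body = json.dumps({"response": inner})

parser = BaseOutputParser(sep="###", is_stream_out=False)
print(parser.parse_model_server_out(FakeResponse(body)))
# -> " select * from users"  (the body is lower-cased, then the last
#    assistant turn is extracted and cleaned)
```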

pilot/prompts/base.py (new file, 49 lines)

@@ -0,0 +1,49 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
def get_buffer_string(
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Get buffer string of messages."""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
else:
raise ValueError(f"Got unsupported message type: {m}")
string_messages.append(f"{role}: {m.content}")
return "\n".join(string_messages)
class PromptValue(BaseModel, ABC):
@abstractmethod
def to_string(self) -> str:
"""Return prompt as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
class ChatPromptValue(PromptValue):
messages: List[BaseMessage]
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return self.messages
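A quick check of the helpers above; message contents are illustrative:

```python
msgs = [
    SystemMessage(content="You are a SQL assistant."),
    HumanMessage(content="show all users"),
]
print(get_buffer_string(msgs))
# System: You are a SQL assistant.
# Human: show all users

value = ChatPromptValue(messages=msgs)
print(value.to_string())    # same buffer rendering as above
print(value.to_messages())  # returns the original message list
```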

@@ -0,0 +1,52 @@
from typing import Any, Callable, Dict, List, Optional
class PromptGenerator:
"""
generating custom prompt strings based on constraints
Compatible with AutoGpt Plugin;
"""
def __init__(self) -> None:
"""
Initialize the PromptGenerator object with empty lists of constraints,
commands, resources, and performance evaluations.
"""
self.constraints = []
self.commands = []
self.resources = []
self.performance_evaluation = []
self.goals = []
self.command_registry = None
self.name = "Bob"
self.role = "AI"
self.response_format = None
def add_command(
self,
command_label: str,
command_name: str,
args=None,
function: Optional[Callable] = None,
) -> None:
"""
Add a command to the commands list with a label, name, and optional arguments.
GB-GPT and Auto-GPT plugin registration command.
Args:
command_label (str): The label of the command.
command_name (str): The name of the command.
args (dict, optional): A dictionary containing argument names and their
values. Defaults to None.
function (callable, optional): A callable function to be called when
the command is executed. Defaults to None.
"""
if args is None:
args = {}
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
command = {
"label": command_label,
"name": command_name,
"args": command_args,
"function": function,
}
self.commands.append(command)
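An illustrative registration through the generator above; the command name and argument are invented for the example:

```python
generator = PromptGenerator()
generator.add_command(
    command_label="Query database",   # hypothetical plugin command
    command_name="db_query",
    args={"sql": "<sql statement>"},
)
print(generator.commands[0])
# {'label': 'Query database', 'name': 'db_query',
#  'args': {'sql': '<sql statement>'}, 'function': None}
```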

pilot/prompts/prompt_new.py (new file, 117 lines)

@@ -0,0 +1,117 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.common.formatting import formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:
from jinja2 import Template
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
return Template(template).render(**kwargs)
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
}
class PromptTemplate(BaseModel, ABC):
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template_scene: str
template_define: str
"""this template define"""
template: str
"""The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
    response_format: str
    """The JSON skeleton the model is asked to answer with."""
    stream_out: bool = True
    """Default to streaming output."""
    output_parser: BaseOutputParser = None
    sep: str = SeparatorStyle.SINGLE.value
    # Assumed completion: these attributes are referenced by add_goals,
    # add_constraint, and _generate_numbered_list below but were not
    # declared in the original diff.
    goals: List[str] = []
    constraints: List[str] = []
    command_registry: Any = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "prompt"
def _generate_command_string(self, command: Dict[str, Any]) -> str:
"""
Generate a formatted string representation of a command.
Args:
command (dict): A dictionary containing command information.
Returns:
str: The formatted command string.
"""
args_string = ", ".join(
f'"{key}": "{value}"' for key, value in command["args"].items()
)
return f'{command["label"]}: "{command["name"]}", args: {args_string}'
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
"""
Generate a numbered list from given items based on the item_type.
Args:
items (list): A list of items to be numbered.
item_type (str, optional): The type of items in the list.
Defaults to 'list'.
Returns:
str: The formatted numbered list.
"""
if item_type == "command":
command_strings = []
if self.command_registry:
command_strings += [
str(item)
for item in self.command_registry.commands.values()
if item.enabled
]
# terminate command is added manually
command_strings += [self._generate_command_string(item) for item in items]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
def add_goals(self, goal: str) -> None:
self.goals.append(goal)
def add_constraint(self, constraint: str) -> None:
"""
Add a constraint to the constraints list.
Args:
constraint (str): The constraint to be added.
"""
self.constraints.append(constraint)
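For orientation, a hedged sketch of constructing this PromptTemplate. All field values are illustrative; the real scene definitions (e.g. in pilot/scene/chat_db/prompt.py, which is not shown in this diff) will differ:

```python
import json

from pilot.common.schema import SeparatorStyle
from pilot.out_parser.base import BaseOutputParser
from pilot.prompts.prompt_new import PromptTemplate

# Illustrative response skeleton only.
RESPONSE_FORMAT = {"thoughts": {"reasoning": "...", "speak": "..."}, "sql": "..."}

prompt = PromptTemplate(
    template_scene="chat_with_db",
    template_define="You are a SQL expert. Answer with a single JSON object.",
    template=(
        "Question: {input}\nDialect: {dialect}\nTables: {table_info}\n"
        "Return at most {top_k} rows. Respond in this format:\n{response}"
    ),
    input_variables=["input", "dialect", "table_info", "top_k"],
    response_format=json.dumps(RESPONSE_FORMAT),
    stream_out=False,
    output_parser=BaseOutputParser(sep=SeparatorStyle.SINGLE.value, is_stream_out=False),
)

# format() injects the JSON response skeleton as the `response` variable.
print(prompt.format(input="show all users", dialect="mysql",
                    table_info="users(id, name)", top_k="5"))
```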

@@ -0,0 +1,363 @@
from __future__ import annotations
import json
import yaml
from string import Formatter
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.out_parser.base import BaseOutputParser
from pilot.prompts.base import PromptValue
from pilot.scene.base_message import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pilot.common.formatting import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:
from jinja2 import Template
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
return Template(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
input_variables_set = set(input_variables)
valid_variables = _get_jinja2_variables_from_template(template)
missing_variables = valid_variables - input_variables_set
extra_variables = input_variables_set - valid_variables
error_message = ""
if missing_variables:
error_message += f"Missing variables: {missing_variables} "
if extra_variables:
error_message += f"Extra variables: {extra_variables}"
if error_message:
raise KeyError(error_message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
}
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
) -> None:
"""Check that template string is valid."""
if template_format not in DEFAULT_FORMATTER_MAPPING:
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
raise ValueError(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
try:
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
validator_func(template, input_variables)
except KeyError as e:
raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. "
+ str(e)
)
class BasePromptTemplate(BaseModel, ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
prompt_dict["_type"] = self._prompt_type
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
Args:
file_path: Path to directory to save prompt to.
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class StringPromptValue(PromptValue):
text: str
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt should expose the format method, returning a prompt."""
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))
class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM.
Example:
.. code-block:: python
from langchain import PromptTemplate
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
"""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template: str
"""The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
"""Whether or not to try validating the template."""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "prompt"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values["validate_template"]:
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
return values
@classmethod
def from_examples(
cls,
examples: List[str],
suffix: str,
input_variables: List[str],
example_separator: str = "\n\n",
prefix: str = "",
**kwargs: Any,
) -> PromptTemplate:
"""Take examples in list format with prefix and suffix to create a prompt.
Intended to be used as a way to dynamically create a prompt from examples.
Args:
examples: List of examples to use in the prompt.
suffix: String to go after the list of examples. Should generally
set up the user's input.
input_variables: A list of variable names the final prompt template
will expect.
example_separator: The separator to use in between examples. Defaults
to two new line characters.
prefix: String that should go before any examples. Generally includes
examples. Default to an empty string.
Returns:
The final prompt generated.
"""
template = example_separator.join([prefix, *examples, suffix])
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_file(
cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any
) -> PromptTemplate:
"""Load a prompt from a file.
Args:
template_file: The path to the file containing the prompt template.
input_variables: A list of variable names the final prompt template
will expect.
Returns:
The prompt loaded from the file.
"""
with open(str(template_file), "r") as f:
template = f.read()
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
if "template_format" in kwargs and kwargs["template_format"] == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
else:
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
# For backwards compatibility.
Prompt = PromptTemplate
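The from_template constructor above infers the input variables from the template string; a minimal example using the PromptTemplate defined in this file:

```python
prompt = PromptTemplate.from_template("Tell me a {adjective} joke about {content}.")
print(prompt.input_variables)  # -> ['adjective', 'content']
print(prompt.format(adjective="funny", content="databases"))
# -> Tell me a funny joke about databases.
```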

pilot/scene/__init__.py (new empty file)

pilot/scene/base.py (new file, 9 lines)

@@ -0,0 +1,9 @@
from enum import Enum
class ChatScene(Enum):
ChatWithDb = "chat_with_db"
ChatExecution = "chat_execution"
ChatKnowledge = "chat_default_knowledge"
ChatNewKnowledge = "chat_new_knowledge"
ChatNormal = "chat_normal"

pilot/scene/base_chat.py (new file, 104 lines)

@@ -0,0 +1,104 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field, root_validator, validator, Extra
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
import requests
from urllib.parse import urljoin
import pilot.configs.config
from pilot.scene.message import OnceConversation
from pilot.prompts.prompt_new import PromptTemplate
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.common.schema import SeparatorStyle
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
CFG = Config()
class BaseChat(ABC):
chat_scene: str = None
llm_model: Any = None
temperature: float = 0.6
max_new_tokens: int = 1024
# By default, keep the last two rounds of conversation records as the context
chat_retention_rounds: int = 2
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
def __init__(self, chat_mode, chat_session_id, current_user_input):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
### TODO
self.memory = FileHistoryMemory(chat_session_id)
### load prompt template
self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value
]
self.history_message: List[OnceConversation] = []
self.current_message: OnceConversation = OnceConversation()
self.current_tokens_used: int = 0
        ### load the chat history for this chat_session_id
        self.history_message = self._load_history(self.chat_session_id)
@property
def chat_type(self) -> str:
raise NotImplementedError("Not supported for this chat type.")
def call(self):
pass
def chat_show(self):
pass
def current_ai_response(self) -> str:
pass
def _load_history(self, session_id: str) -> List[OnceConversation]:
"""
load chat history by session_id
Args:
session_id:
Returns:
"""
return self.memory.messages()
def generate(self, p) -> str:
"""
generate context for LLM input
Args:
p:
Returns:
"""
pass

pilot/scene/base_message.py (new file, 148 lines)

@@ -0,0 +1,148 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from pydantic import BaseModel, Extra, Field, root_validator
class PromptValue(BaseModel, ABC):
@abstractmethod
def to_string(self) -> str:
"""Return prompt as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
class BaseMessage(BaseModel):
"""Message object."""
content: str
additional_kwargs: dict = Field(default_factory=dict)
@property
@abstractmethod
def type(self) -> str:
"""Type of the message, used for serialization."""
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
example: bool = False
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "human"
class AIMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
example: bool = False
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "ai"
class ViewMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
example: bool = False
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "view"
class SystemMessage(BaseMessage):
"""Type of message that is a system message."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "system"
class Generation(BaseModel):
"""Output of a single generation."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
class ChatGeneration(Generation):
"""Output of a single generation."""
text = ""
message: BaseMessage
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["text"] = values["message"].content
return values
class ChatResult(BaseModel):
"""Class that contains all relevant information for a Chat Result."""
generations: List[ChatGeneration]
"""List of the things generated."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class LLMResult(BaseModel):
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "view":
return ViewMessage(**message["data"])
else:
raise ValueError(f"Got unexpected type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages]
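A round-trip sketch for the serialization helpers above; contents are illustrative:

```python
msgs = [
    HumanMessage(content="show all users"),
    AIMessage(content="SELECT * FROM users"),
]
as_dicts = messages_to_dict(msgs)
# [{'type': 'human', 'data': {...}}, {'type': 'ai', 'data': {...}}]
restored = messages_from_dict(as_dicts)
assert restored[0].content == "show all users"
assert restored[1].type == "ai"
```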

pilot/scene/chat_db/chat.py (new file, 279 lines)

@@ -0,0 +1,279 @@
import requests
import datetime
import threading
import json
import traceback
from urllib.parse import urljoin
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from typing import Any, Iterable, List, Optional
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
HumanMessage,
AIMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.scene.chat_db.out_parser import SqlAction
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_db.prompt import chat_db_prompt
from pilot.out_parser.base import BaseOutputParser
from pilot.scene.chat_db.out_parser import DbChatOutputParser
CFG = Config()
class ChatWithDb(BaseChat):
chat_scene: str = ChatScene.ChatWithDb.value
"""Number of results to return from the query"""
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
self.db_name = db_name
self.database = CFG.local_db
        # Prepare DB info (obtain a connection to the specified database)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
def call(self) -> str:
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
# "stop": self.sep_style,
}
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now()
# TODO
self.current_message.tokens = 0
current_prompt = self.prompt_template.format(**input_values)
        ### Build the current conversation: decide whether this is the first-round
        ### prompt and whether a database switch needs special handling
        if self.history_message:
            ## TODO: decide how to handle a database switch when history records exist
logger.info(
f"There are already {len(self.history_message)} rounds of conversations!"
)
self.current_message.add_system_message(current_prompt)
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
"temperature": float(self.temperature),
"max_new_tokens": int(self.max_new_tokens),
"stop": self.prompt_template.sep,
}
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
            ### Call the non-streaming model service endpoint
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
ai_response_text = (
self.prompt_template.output_parser.parse_model_server_out(response)
)
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.database.run(self.db_connect, prompt_define_response.sql)
if hasattr(prompt_define_response, "thoughts"):
if prompt_define_response.thoughts.get("speak"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("speak"), result
)
)
elif prompt_define_response.thoughts.get("reasoning"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("reasoning"), result
)
)
else:
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts, result
)
)
else:
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response, result
)
)
except Exception as e:
print(traceback.format_exc())
            logger.error("model response parse failed: " + str(e))
            self.current_message.add_view_message(
                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
            )
        ### Persist the conversation record
self.memory.append(self.current_message)
    def chat_show(self):
        ret = []
        # A single chat round has exactly one user record and one AI record
        # TODO: render the reasoning process in the front end
        for message in self.current_message.messages:
            if isinstance(message, HumanMessage):
                # Assumed fix: the original indexed ret[-1][-2] on an empty list;
                # start a new [user, ai] row for each user message instead.
                ret.append([message.content, None])
            # whether to show the reasoning process
            if isinstance(message, ViewMessage):
                ret[-1][-1] = message.content
        return ret

    # Temporary, for front-end compatibility
def current_ai_response(self) -> str:
for message in self.current_message.messages:
if message.type == "view":
return message.content
return None
    def generate_llm_text(self) -> str:
        text = self.prompt_template.template_define + self.prompt_template.sep
        ### Process the history messages first
        if len(self.history_message) > self.chat_retention_rounds:
            ### Merge the first round and the most recent rounds of history into the
            ### context; view messages shown to the user must be filtered out
            for first_message in self.history_message[0].messages:
                if not isinstance(first_message, ViewMessage):
                    text += (
                        first_message.type
                        + ":"
                        + first_message.content
                        + self.prompt_template.sep
                    )
            index = self.chat_retention_rounds - 1
            # Assumed fix: the original iterated `self.history_message[-index:].messages`,
            # but a list slice has no `.messages`; iterate the conversations instead.
            for last_conversation in self.history_message[-index:]:
                for last_message in last_conversation.messages:
                    if not isinstance(last_message, ViewMessage):
                        text += (
                            last_message.type
                            + ":"
                            + last_message.content
                            + self.prompt_template.sep
                        )
        else:
            ### Otherwise concatenate the full history directly
for conversation in self.history_message:
for message in conversation.messages:
if not isinstance(message, ViewMessage):
text += (
message.type
+ ":"
+ message.content
+ self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
now_message.type + ":" + now_message.content + self.prompt_template.sep
)
return text
@property
def chat_type(self) -> str:
return ChatScene.ChatWithDb.value
if __name__ == "__main__":
# chat: ChatWithDb = ChatWithDb("chat123", "gpt-user", "查询用户信息")
#
# chat.call()
#
# resp = chat.chat_show()
#
# print(vars(resp))
# memory = FileHistoryMemory("test123")
# once1 = OnceConversation()
# once1.add_user_message("问题测试")
# once1.add_system_message("prompt1")
# once1.add_system_message("prompt2")
# once1.chat_order = 1
# once1.set_start_time(datetime.datetime.now())
# memory.append(once1)
#
# once = OnceConversation()
# once.add_user_message("问题测试2")
# once.add_system_message("prompt3")
# once.add_system_message("prompt4")
# once.chat_order = 2
# once.set_start_time(datetime.datetime.now())
# memory.append(once)
db: Database = CFG.local_db
db_connect = db.get_session("gpt-user")
data = db.run(db_connect, "select * from users")
print(generate_htm_table(data))
#
# print(db.run(db_connect, "select * from users"))
#
# #
# # def print_numbers():
# # db_connect1 = db.get_session("dbgpt-test")
# # cursor1 = db_connect1.execute(text("select * from test_name"))
# # if cursor1.returns_rows:
# # result1 = cursor1.fetchall()
# # print( result1)
# #
# #
# # # create a thread
# # t = threading.Thread(target=print_numbers)
# # # start the thread
# # t.start()
#
# print(db.run(db_connect, "select * from tran_order"))
#
# print(db.run(db_connect, "select count(*) as aa from tran_order"))
#
# print(db.table_simple_info(db_connect))
# my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# index = 3
# last_three_elements = my_list[-index:]
# print(last_three_elements)
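For orientation, a runnable sketch of the prompt layout that generate_llm_text assembles above; the separator value and message contents here are invented for the example, not the real SeparatorStyle value:

sep = "###"  # assumed separator for the sketch only
template_define = "You are an AI designed to answer human questions"
history = [
    ("system", "Only use table users(id, name)"),
    ("human", "How many users are there?"),
]
text = template_define + sep
for msg_type, content in history:
    # each retained, non-view message is appended as "<type>:<content><sep>"
    text += msg_type + ":" + content + sep
print(text)
# You are an AI designed to answer human questions###system:Only use table users(id, name)###human:How many users are there?###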

View File

View File

@@ -0,0 +1,65 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
class SqlAction(NamedTuple):
sql: str
thoughts: Dict
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class DbChatOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out)
def parse_model_server_out(self, response) -> str:
return super().parse_model_server_out(response)
def parse_prompt_response(self, model_out_text):
cleaned_output = model_out_text.rstrip()
if "```json" in cleaned_output:
_, cleaned_output = cleaned_output.split("```json")
if "```" in cleaned_output:
cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing")
json_pattern = r"{(.+?)}"
m = re.search(json_pattern, cleaned_output)
if m:
cleaned_output = m.group(0)
else:
raise ValueError("model server out not fllow the prompt!")
response = json.loads(cleaned_output)
sql, thoughts = response["sql"], response["thoughts"]
return SqlAction(sql, thoughts)
def parse_view_response(self, speak, data) -> str:
### Render the tool output data as an HTML table view
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text
@property
def _type(self) -> str:
return "sql_chat"

View File

@@ -0,0 +1,67 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
PROMPT_SUFFIX = """Only use the following tables:
{table_info}
Question: {input}
"""
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in the question a specific number of examples to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table; only ask for the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
Ensure the response is valid JSON that can be parsed by Python json.loads
"""
RESPONSE_FORMAT = {
"thoughts": {
"reasoning": "reasoning",
"speak": "thoughts summary to say to user",
},
"sql": "SQL Query to run",
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_STREAM_OUT = False
chat_db_prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDb.value,
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=DbChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})
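For illustration, the combined template can be filled with plain str.format; the schema and question below are invented:

filled_prompt = (_DEFAULT_TEMPLATE + PROMPT_SUFFIX + PROMPT_RESPONSE).format(
    dialect="mysql",
    top_k=5,
    table_info="users(id, name, city)",
    input="How many users live in Hangzhou?",
    response=json.dumps(RESPONSE_FORMAT, indent=4),
)
print(filled_prompt)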

View File

View File

@@ -0,0 +1,28 @@
from typing import List
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.message import OnceConversation
from pilot.scene.base import ChatScene
class ChatWithPlugin(BaseChat):
chat_scene: str = ChatScene.ChatExecution.value
def __init__(self, chat_mode, chat_session_id, current_user_input):
super().__init__(chat_mode, chat_session_id, current_user_input)
def call(self):
super().call()
def chat_show(self):
return super().chat_show()
def _load_history(self, session_id: str) -> List[OnceConversation]:
return super()._load_history(session_id)
def generate(self, p) -> str:
return super().generate(p)
@property
def chat_type(self) -> str:
return ChatScene.ChatExecution.value

View File

@@ -0,0 +1,17 @@
from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton
from pilot.scene.chat_db.chat import ChatWithDb
from pilot.scene.chat_execution.chat import ChatWithPlugin
class ChatFactory(metaclass=Singleton):
@staticmethod
def get_implementation(chat_mode, **kwargs):
chat_classes = BaseChat.__subclasses__()
implementation = None
for cls in chat_classes:
if cls.chat_scene == chat_mode:
implementation = cls(**kwargs)
if implementation is None:
raise Exception("Invalid implementation name:" + chat_mode)
return implementation
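A hedged usage sketch of the factory. Note that with the current signature chat_mode is supplied twice: once as the lookup key and once inside the constructor kwargs forwarded to cls(**kwargs). The session id and user input are invented:

from pilot.scene.base import ChatScene

chat = ChatFactory.get_implementation(
    ChatScene.ChatExecution.value,
    chat_mode=ChatScene.ChatExecution.value,
    chat_session_id="session-123",
    current_user_input="list today's orders",
)
chat.call()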

View File

View File

View File

@@ -0,0 +1,31 @@
def stream_write_and_read(lst):
# Flatten the iterable with `yield from` first, then accept writes via send()
yield from lst
while True:
val = yield
lst.append(val)
if __name__ == "__main__":
# Create an empty list
my_list = []
# Use the generator to write data
stream_writer = stream_write_and_read(my_list)
next(stream_writer)
stream_writer.send(10)
print(1)
stream_writer.send(20)
print(2)
stream_writer.send(30)
print(3)
# Use the generator to read data
stream_reader = stream_write_and_read(my_list)
next(stream_reader)
print(stream_reader.send(None))
print(stream_reader.send(None))
print(stream_reader.send(None))
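# Expected console output of this demo, for reference: the writer section
# prints 1, 2, 3 while appending 10, 20 and 30 to my_list; the reader then
# prints 20, 30 and finally None, because the priming next() call consumed
# the first element (10) and the exhausted `yield from` falls through to
# the bare `yield`, which produces None.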

101
pilot/scene/message.py Normal file
View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from datetime import datetime, timedelta
from pydantic import BaseModel, Field, root_validator, validator
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
)
from pilot.scene.base_message import (
BaseMessage,
AIMessage,
HumanMessage,
SystemMessage,
ViewMessage,
messages_to_dict,
messages_from_dict,
)
class OnceConversation:
"""
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
"""
def __init__(self):
self.messages: List[BaseMessage] = []
self.start_date: str = ""
self.chat_order: int = 0
self.cost: int = 0
self.tokens: int = 0
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
has_message = any(
isinstance(instance, HumanMessage) for instance in self.messages
)
if has_message:
raise ValueError("Already Have Human message")
self.messages.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store"""
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
if has_message:
raise ValueError("Already Have Ai message")
self.messages.append(AIMessage(content=message))
""" """
def add_view_message(self, message: str) -> None:
"""Add an AI message to the store"""
self.messages.append(ViewMessage(content=message))
""" """
def add_system_message(self, message: str) -> None:
"""Add an AI message to the store"""
self.messages.append(SystemMessage(content=message))
def set_start_time(self, dt: datetime):
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
self.start_date = dt_str
def clear(self) -> None:
"""Remove all messages from the store"""
self.messages.clear()
self.session_id = None
def _conversation_to_dic(once: OnceConversation) -> dict:
start_str: str = ""
if once.start_date:
if isinstance(once.start_date, datetime):
start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
else:
start_str = once.start_date
return {
"chat_order": once.chat_order,
"start_date": start_str,
"cost": once.cost if once.cost else 0,
"tokens": once.tokens if once.tokens else 0,
"messages": messages_to_dict(once.messages),
}
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
return [_conversation_to_dic(m) for m in conversations]
def conversation_from_dict(once: dict) -> OnceConversation:
conversation = OnceConversation()
conversation.cost = once.get("cost", 0)
conversation.tokens = once.get("tokens", 0)
conversation.start_date = once.get("start_date", "")
conversation.chat_order = int(once.get("chat_order", 0))
conversation.messages = messages_from_dict(once.get("messages", []))
return conversation
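A small round-trip sketch of the serialization helpers above; message contents are invented, and it assumes messages_to_dict / messages_from_dict faithfully round-trip the message classes:

once = OnceConversation()
once.add_user_message("How many users signed up today?")
once.add_ai_message("There are 42 new users today.")
once.set_start_time(datetime.now())
once.chat_order = 1

# dict form is JSON-friendly; restoring gives an equivalent conversation
restored = conversation_from_dict(_conversation_to_dic(once))
assert restored.chat_order == 1
assert len(restored.messages) == 2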

View File

@@ -1,49 +0,0 @@
[tool.poetry]
name = "db-gpt"
version = "0.0.6"
description = "Interact with your data and environment privately"
authors = []
readme = "README.md"
license = "MIT"
packages = [{include = "db_gpt"}]
repository = "https://www.github.com/csunny/DB-GPT"
[tool.poetry.dependencies]
python = "^3.10"
accelerate = "^0.16"
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
myst_parser = "^0.18.1"
nbsphinx = "^0.8.9"
sphinx = "^4.5.0"
sphinx-autobuild = "^2021.3.14"
sphinx_book_theme = "^0.3.3"
sphinx_rtd_theme = "^1.0.0"
sphinx-typlog-theme = "^0.8.0"
sphinx-panels = "^0.6.0"
toml = "^0.10.2"
myst-nb = "^0.17.1"
linkchecker = "^10.2.1"
sphinx-copybutton = "^0.5.1"
[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.7.0"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -46,7 +46,7 @@ wandb
llama-index==0.5.27 llama-index==0.5.27
pymysql pymysql
unstructured==0.6.3 unstructured==0.6.3
grpcio==1.47.5 grpcio==1.54.2
auto-gpt-plugin-template auto-gpt-plugin-template
pymdown-extensions pymdown-extensions

38
setup.py Normal file
View File

@@ -0,0 +1,38 @@
from typing import List
import setuptools
from setuptools import find_packages
with open("README.md", "r") as fh:
long_description = fh.read()
def parse_requirements(file_name: str) -> List[str]:
with open(file_name) as f:
return [
require.strip()
for require in f
if require.strip() and not require.startswith("#")
]
setuptools.setup(
name="DB-GPT",
packages=find_packages(),
version="0.1.0",
author="csunny",
author_email="cfqcsunny@gmail.com",
description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment."
" With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.",
long_description=long_description,
long_description_content_type="text/markdown",
install_requires=parse_requirements("requirements.txt"),
url="https://github.com/csunny/DB-GPT",
license="https://opensource.org/license/mit/",
python_requires=">=3.10",
entry_points={
"console_scripts": [
"dbgpt_server=pilot.server:webserver",
],
},
)