add llm support for gpt4all (#209)

add llm support for gpt4all #178 #138
magic.chen 2023-06-14 11:35:42 +08:00 committed by GitHub
commit e3f91e98b4
29 changed files with 218 additions and 120 deletions

View File

@@ -77,19 +77,23 @@ def load_native_plugins(cfg: Config):
     print("load_native_plugins")
     ### TODO: pull the main branch by default; pull release versions later
     branch_name = cfg.plugins_git_branch
-    native_plugin_repo ="DB-GPT-Plugins"
+    native_plugin_repo = "DB-GPT-Plugins"
     url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
-    response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name),
-                            headers={'Authorization': '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'})
+    response = requests.get(
+        url.format(repo=native_plugin_repo, branch=branch_name),
+        headers={"Authorization": "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"},
+    )
     if response.status_code == 200:
         plugins_path_path = Path(PLUGINS_DIR)
-        files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
+        files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*"))
         for file in files:
             os.remove(file)
         now = datetime.datetime.now()
-        time_str = now.strftime('%Y%m%d%H%M%S')
-        file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
+        time_str = now.strftime("%Y%m%d%H%M%S")
+        file_name = (
+            f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
+        )
         print(file_name)
         with open(file_name, "wb") as f:
             f.write(response.content)
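
Editorial note on the hunk above: a GitHub personal access token is hardcoded into the request headers. A minimal sketch of reading the credential from the environment instead (the GITHUB_TOKEN variable name and the "main" branch are assumptions, not part of this commit):

    import os

    import requests

    url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
    token = os.environ.get("GITHUB_TOKEN")  # assumed variable name
    headers = {"Authorization": f"token {token}"} if token else {}
    response = requests.get(
        url.format(repo="DB-GPT-Plugins", branch="main"),
        headers=headers,
    )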

View File

@@ -66,7 +66,6 @@ class Database:
         self._sample_rows_in_table_info = set()
         self._indexes_in_table_info = indexes_in_table_info

     @classmethod
     def from_uri(
         cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any

@@ -399,7 +398,6 @@ class Database:
         ans = cursor.fetchall()
         return ans[0][1]

     def get_fields(self, table_name):
         """Get column fields about specified table."""
         session = self._db_sessions()

View File

@@ -14,8 +14,8 @@ LOGDIR = os.path.join(ROOT_PATH, "logs")
 DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
 DATA_DIR = os.path.join(PILOT_PATH, "data")
 nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
 PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
 FONT_DIR = os.path.join(PILOT_PATH, "fonts")

 current_directory = os.getcwd()

@@ -43,6 +43,7 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
+    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }
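
For orientation (an illustration, not part of the diff): the new "gptj-6b" entry maps the model name to a GPT4All ggml file, and the adapters registered later in this commit select themselves by a substring match on that path, roughly:

    # Assumes the LLM_MODEL_CONFIG dict from the hunk above is in scope.
    model_path = LLM_MODEL_CONFIG["gptj-6b"]  # .../ggml-gpt4all-j-v1.3-groovy.bin
    assert "gpt4all" in model_path  # the check GPT4AllAdapter.match() performs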

View File

@@ -6,6 +6,7 @@ from pilot.configs.config import Config

 CFG = Config()


 class ClickHouseConnector(RDBMSDatabase):
     """ClickHouseConnector"""

@@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase):
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]

     @classmethod
     def from_config(cls) -> RDBMSDatabase:
         """
         Todo password encryption
         Returns:
         """
-        return cls.from_uri_db(cls,
-                               CFG.LOCAL_DB_PATH,
-                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        return cls.from_uri_db(
+            cls,
+            CFG.LOCAL_DB_PATH,
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )

     @classmethod
-    def from_uri_db(cls, db_path: str,
-                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
+    def from_uri_db(
+        cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> RDBMSDatabase:
         db_url: str = cls.connect_driver + "://" + db_path
         return cls.from_uri(db_url, engine_args, **kwargs)
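
An editorial aside that applies equally to the ClickHouse, DuckDB, and base RDBMSDatabase versions of from_config: from_uri_db is a classmethod, so the explicit cls in cls.from_uri_db(cls, ...) is bound a second time as db_path, and the call would raise TypeError: got multiple values for argument 'engine_args'. A corrected sketch, assuming the intent was to pass only the path:

    @classmethod
    def from_config(cls) -> RDBMSDatabase:
        # Sketch only: drop the explicit cls; the classmethod already binds it.
        return cls.from_uri_db(
            CFG.LOCAL_DB_PATH,
            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
        )

Separately, the base-class from_uri_db shown further below builds db_url from the CFG globals rather than from its host, user, and pwd parameters, so the values forwarded by from_config are effectively unused.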

View File

@@ -6,6 +6,7 @@ from pilot.configs.config import Config

 CFG = Config()


 class DuckDbConnect(RDBMSDatabase):
     """Connect Duckdb Database fetch MetaData
     Args:

@@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase):
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]

     @classmethod
     def from_config(cls) -> RDBMSDatabase:
         """
         Todo password encryption
         Returns:
         """
-        return cls.from_uri_db(cls,
-                               CFG.LOCAL_DB_PATH,
-                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        return cls.from_uri_db(
+            cls,
+            CFG.LOCAL_DB_PATH,
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )

     @classmethod
-    def from_uri_db(cls, db_path: str,
-                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
+    def from_uri_db(
+        cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
+    ) -> RDBMSDatabase:
         db_url: str = cls.connect_driver + "://" + db_path
         return cls.from_uri(db_url, engine_args, **kwargs)

View File

@@ -5,9 +5,6 @@ from typing import Optional, Any
 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase


 class MSSQLConnect(RDBMSDatabase):
     """Connect MSSQL Database fetch MetaData
     Args:

@@ -18,6 +15,4 @@ class MSSQLConnect(RDBMSDatabase):
     dialect: str = "mssql"
     driver: str = "pyodbc"
-    default_db = ["master", "model", "msdb", "tempdb","modeldb", "resource"]
+    default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]

View File

@@ -5,9 +5,6 @@ from typing import Optional, Any
 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase


 class MySQLConnect(RDBMSDatabase):
     """Connect MySQL Database fetch MetaData
     Args:

@@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase):
     driver: str = "pymysql"
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@@ -2,8 +2,10 @@
 # -*- coding:utf-8 -*-

 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase


 class OracleConnector(RDBMSDatabase):
     """OracleConnector"""

     type: str = "ORACLE"
     driver: str = "oracle"

View File

@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-

 from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase


 class PostgresConnector(RDBMSDatabase):
     """PostgresConnector is a class which Connector"""

View File

@@ -57,18 +57,19 @@ CFG = Config()
 if __name__ == "__main__":

-    def __extract_json(s):
-        i = s.index('{')
-        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1:], start=i + 1):
-            if c == '}':
-                count -= 1
-            elif c == '{':
-                count += 1
-            if count == 0:
-                break
-        assert (count == 0)  # check that the matching final '}' was found
-        return s[i:j + 1]
+    def __extract_json(s):
+        i = s.index("{")
+        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
+            if c == "}":
+                count -= 1
+            elif c == "{":
+                count += 1
+            if count == 0:
+                break
+        assert count == 0  # check that the matching final '}' was found
+        return s[i : j + 1]

     ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
     print(__extract_json(ss))

View File

@@ -35,13 +35,12 @@ class RDBMSDatabase(BaseConnect):
     """SQLAlchemy wrapper around a database."""

     def __init__(
         self,
         engine,
         schema: Optional[str] = None,
         metadata: Optional[MetaData] = None,
         ignore_tables: Optional[List[str]] = None,
         include_tables: Optional[List[str]] = None,
     ):
         """Create engine from database URI."""
         self._engine = engine

@@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect):
         Todo password encryption
         Returns:
         """
-        return cls.from_uri_db(cls,
-                               CFG.LOCAL_DB_HOST,
-                               CFG.LOCAL_DB_PORT,
-                               CFG.LOCAL_DB_USER,
-                               CFG.LOCAL_DB_PASSWORD,
-                               engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True})
+        return cls.from_uri_db(
+            cls,
+            CFG.LOCAL_DB_HOST,
+            CFG.LOCAL_DB_PORT,
+            CFG.LOCAL_DB_USER,
+            CFG.LOCAL_DB_PASSWORD,
+            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
+        )

     @classmethod
-    def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None,
-                    engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase:
-        db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str(
-            CFG.LOCAL_DB_PORT)
+    def from_uri_db(
+        cls,
+        host: str,
+        port: int,
+        user: str,
+        pwd: str,
+        db_name: str = None,
+        engine_args: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> RDBMSDatabase:
+        db_url: str = (
+            cls.connect_driver
+            + "://"
+            + CFG.LOCAL_DB_USER
+            + ":"
+            + CFG.LOCAL_DB_PASSWORD
+            + "@"
+            + CFG.LOCAL_DB_HOST
+            + ":"
+            + str(CFG.LOCAL_DB_PORT)
+        )
         if cls.dialect:
             db_url = cls.dialect + "+" + db_url
         if db_name:

@@ -81,7 +99,7 @@ class RDBMSDatabase(BaseConnect):
     @classmethod
     def from_uri(
         cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
     ) -> RDBMSDatabase:
         """Construct a SQLAlchemy engine from URI."""
         _engine_args = engine_args or {}

@@ -167,7 +185,7 @@ class RDBMSDatabase(BaseConnect):
             tbl
             for tbl in self._metadata.sorted_tables
             if tbl.name in set(all_table_names)
             and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
         ]

         tables = []

@@ -180,7 +198,7 @@ class RDBMSDatabase(BaseConnect):
             create_table = str(CreateTable(table).compile(self._engine))
             table_info = f"{create_table.rstrip()}"
             has_extra_info = (
                 self._indexes_in_table_info or self._sample_rows_in_table_info
             )
             if has_extra_info:
                 table_info += "\n\n/*"

View File

@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-

 import torch
+import os
 from typing import List
 from functools import cache
 from transformers import (

@@ -183,18 +184,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):

 class GPT4AllAdapter(BaseLLMAdaper):
-    """A light version for someone who want practise LLM use laptop."""
+    """
+    A light version for anyone who wants to practise with LLMs on a laptop.
+    For all model names, see: https://gpt4all.io/models/models.json
+    """

     def match(self, model_path: str):
         return "gpt4all" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO
-        pass
+        import gpt4all
+
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
+            model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
+        else:
+            path, file = os.path.split(model_path)
+            model = gpt4all.GPT4All(model_path=path, model_name=file)
+        return model, None


 class ProxyllmAdapter(BaseLLMAdaper):
     """The model adapter for local proxy"""

     def match(self, model_path: str):

@@ -209,6 +218,7 @@ register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
 register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
+register_llm_model_adapters(GPT4AllAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test_py, remove this later
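
A short usage sketch for the GPT4AllAdapter loader added above (illustrative only: the pilot.model.adapter module path and the model file location are assumptions, and the call needs the ggml file to exist locally):

    from pilot.model.adapter import GPT4AllAdapter  # assumed module path

    adapter = GPT4AllAdapter()
    model_path = "/data/models/ggml-gpt4all-j-v1.3-groovy.bin"  # assumed location
    assert adapter.match(model_path)  # selected via the "gpt4all" substring

    # loader() splits the path so gpt4all can locate the directory and file name.
    model, tokenizer = adapter.loader(model_path, from_pretrained_kwargs={})
    print(type(model), tokenizer)  # gpt4all.GPT4All, None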

View File

@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-

+import threading
+import sys
+import time


+def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    role, query = prompt.split(stop)[1].split(":")
+    print(f"gpt4all, role: {role}, query: {query}")

+    def worker():
+        model.generate(prompt=query, streaming=True)

+    t = threading.Thread(target=worker)
+    t.start()

+    while t.is_alive():
+        yield sys.stdout.output
+        time.sleep(0.01)
+    t.join()
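
Note that gpt4all_generate_stream yields sys.stdout.output, which presupposes that the server has replaced sys.stdout with a buffering wrapper so the tokens gpt4all prints while streaming can be read back (this is also why the worker below warns against re-enabling its own print). A minimal sketch of such a wrapper, purely an assumption about the surrounding setup, not code from this commit:

    import sys


    class CapturingStdout:
        """Hypothetical stdout replacement with an `output` attribute,
        matching what gpt4all_generate_stream reads back."""

        def __init__(self):
            self.output = ""

        def write(self, text):
            self.output += text  # accumulate the tokens gpt4all prints
            return len(text)

        def flush(self):
            pass


    # Assumed to be installed once at server start-up:
    sys.stdout = CapturingStdout()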

View File

@@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
         }
     )

     # Move the last user's information to the end
     temp_his = history[::-1]
     last_user_input = None
     for m in temp_his:

@@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     text = ""
     for line in res.iter_lines():
         if line:
-            json_data = line.split(b': ', 1)[1]
+            json_data = line.split(b": ", 1)[1]
             decoded_line = json_data.decode("utf-8")
             if decoded_line.lower() != "[DONE]".lower():
                 obj = json.loads(json_data)

View File

@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
-        if data["error_code"] == 0:
+        if data.get("error_code", 0) == 0:
             if "vicuna" in CFG.LLM_MODEL:
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()

@@ -121,17 +121,17 @@ class BaseOutputParser(ABC):
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])

     def __extract_json(slef, s):
-        i = s.index('{')
+        i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1:], start=i + 1):
-            if c == '}':
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
+            if c == "}":
                 count -= 1
-            elif c == '{':
+            elif c == "{":
                 count += 1
             if count == 0:
                 break
-        assert (count == 0)  # check that the matching final '}' was found
-        return s[i:j + 1]
+        assert count == 0  # check that the matching final '}' was found
+        return s[i : j + 1]

     def parse_prompt_response(self, model_out_text) -> T:
         """

View File

@@ -134,6 +134,7 @@ class BaseChat(ABC):
         return payload

     def stream_call(self):
+        # TODO Retry when server connection error
         payload = self.__call_base()

         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11

@@ -187,19 +188,19 @@ class BaseChat(ABC):
                 )
             )
             # ### MOCK
             # ai_response_text = """{
             #     "thoughts": "You can join the users table and the tran_order table, group by city to count the orders, and display the result as a histogram.",
             #     "reasoning": "To analyze how users are distributed across cities, query the users and tran_order tables and combine them with a LEFT JOIN. Group by city and count the orders in each city. A histogram makes the per-city order counts easy to see and compare.",
             #     "speak": "Based on your analysis goal, I queried the user and order tables, counted the orders per city, and generated a histogram.",
             #     "command": {
             #         "name": "histogram-executor",
             #         "args": {
             #             "title": "Histogram of order counts by city",
             #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
             #         }
             #     }
             # }"""
             self.current_message.add_ai_message(ai_response_text)
             prompt_define_response = (

View File

@@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat):
     def __list_to_prompt_str(self, list: List) -> str:
         return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))

     def generate(self, p) -> str:
         return super().generate(p)

View File

@@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser):
         command, thoughts, speak = (
             response["command"],
             response["thoughts"],
-            response["speak"]
+            response["speak"],
         )
         return PluginAction(command, speak, thoughts)

View File

@@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat):
             context = context[:2000]
             input_values = {"context": context, "question": self.current_user_input}
         except NoIndexException:
-            raise ValueError("you have no default knowledge store, please execute python knowledge_init.py")
+            raise ValueError(
+                "you have no default knowledge store, please execute python knowledge_init.py"
+            )
         return input_values

     def do_with_prompt_response(self, prompt_response):
def do_with_prompt_response(self, prompt_response): def do_with_prompt_response(self, prompt_response):

View File

@@ -5,7 +5,6 @@ import sys
 from dotenv import load_dotenv

 if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
     print("Setting random seed to 42")
     random.seed(42)

View File

@@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:

 class VicunaChatAdapter(BaseChatAdpter):
     """Model chat Adapter for vicuna"""

     def match(self, model_path: str):

@@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):

 class CodeT5ChatAdapter(BaseChatAdpter):
     """Model chat adapter for CodeT5"""

     def match(self, model_path: str):

@@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):

 class CodeGenChatAdapter(BaseChatAdpter):
     """Model chat adapter for CodeGen"""

     def match(self, model_path: str):

@@ -127,11 +124,22 @@ class GorillaChatAdapter(BaseChatAdpter):
         return generate_stream


+class GPT4AllChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "gpt4all" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
+
+        return gpt4all_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
 register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
+register_llm_model_chat_adapter(GPT4AllChatAdapter)

 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)

View File

@@ -39,9 +39,13 @@ class ModelWorker:
         )

         if not isinstance(self.model, str):
-            if hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_sequence_length"
+            ):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_position_embeddings"
+            ):
                 self.context_len = self.model.config.max_position_embeddings

         else:

@@ -69,7 +73,10 @@ class ModelWorker:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
-                print("output: ", output)
+                # Please do not enable this output in production!
+                # The gpt4all generation thread shares stdout with the parent
+                # process, and printing here can corrupt the frontend output.
+                # print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,

@@ -79,6 +86,12 @@ class ModelWorker:
         except torch.cuda.CudaError:
             ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
             yield json.dumps(ret).encode() + b"\0"
+        except Exception as e:
+            ret = {
+                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
+                "error_code": 0,
+            }
+            yield json.dumps(ret).encode() + b"\0"

     def get_embeddings(self, prompt):
         return get_embeddings(self.model, self.tokenizer, prompt)
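
Note that both error handlers above keep error_code at 0, so failures reach the client as ordinary text frames. A hedged sketch of consuming this stream (the endpoint URL and payload shape are assumptions; the b"\0" frame delimiter comes from the worker code above):

    import json

    import requests

    params = {"prompt": "### Human: hello ###", "stop": "###"}  # assumed shape
    res = requests.post(
        "http://127.0.0.1:8000/generate_stream",  # hypothetical server address
        json=params,
        stream=True,
    )
    # The worker yields JSON frames terminated by a b"\0" byte.
    for frame in res.iter_lines(delimiter=b"\0"):
        if frame:
            data = json.loads(frame.decode("utf-8"))
            print(data["text"])  # errors arrive here too, since error_code stays 0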

View File

@@ -667,8 +667,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info(f"args: {args}")

     # init config
     cfg = Config()

     load_native_plugins(cfg)

@@ -682,7 +682,7 @@ if __name__ == "__main__":
         "pilot.commands.built_in.audio_text",
         "pilot.commands.built_in.image_gen",
     ]

     # exclude commands
     command_categories = [
         x for x in command_categories if x not in cfg.disabled_command_categories
     ]

View File

@@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding):
     def read(self):
         """Load from markdown path."""
         loader = EncodeTextLoader(self.file_path)
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm",
+            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+            chunk_overlap=200,
+        )
         return loader.load_and_split(textsplitter)

     @register
@register @register

View File

@@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding):
         # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         # )
         textsplitter = SpacyTextSplitter(
-            pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
+            pipeline="zh_core_web_sm",
+            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+            chunk_overlap=200,
         )
         return loader.load_and_split(textsplitter)

View File

@@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding):
     def read(self):
         """Load from ppt path."""
         loader = UnstructuredPowerPointLoader(self.file_path)
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm",
+            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+            chunk_overlap=200,
+        )
         return loader.load_and_split(textsplitter)

     @register

View File

@@ -78,7 +78,7 @@ class DBSummaryClient:
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
         )
-        table_docs =knowledge_embedding_client.similar_search(query, topk)
+        table_docs = knowledge_embedding_client.similar_search(query, topk)
         ans = [d.page_content for d in table_docs]
         return ans

@@ -147,8 +147,6 @@ class DBSummaryClient:
         logger.info("init db profile success...")


 def _get_llm_response(query, db_input, dbsummary):
     chat_param = {
         "temperature": 0.7,

View File

@@ -43,15 +43,14 @@ CFG = Config()
 #     "tps": 50
 # }


 class MysqlSummary(DBSummary):
     """Get mysql summary template."""

     def __init__(self, name):
         self.name = name
         self.type = "MYSQL"
-        self.summery = (
-            """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
-        )
+        self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
         self.tables = {}
         self.tables_info = []
         self.vector_tables_info = []

@@ -92,9 +91,12 @@ class MysqlSummary(DBSummary):
             self.tables[table_name] = table_summary.get_columns()
             self.table_columns_info.append(table_summary.get_columns())
             # self.table_columns_json.append(table_summary.get_summary_json())
-            table_profile = "table name:{table_name},table description:{table_comment}".format(
-                table_name=table_name, table_comment=self.db.get_show_create_table(table_name)
-            )
+            table_profile = (
+                "table name:{table_name},table description:{table_comment}".format(
+                    table_name=table_name,
+                    table_comment=self.db.get_show_create_table(table_name),
+                )
+            )
             self.table_columns_json.append(table_profile)
             # self.tables_info.append(table_summary.get_summery())

@@ -108,7 +110,11 @@ class MysqlSummary(DBSummary):
     def get_db_summery(self):
         return self.summery.format(
-            name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000
+            name=self.name,
+            type=self.type,
+            tables=";".join(self.vector_tables_info),
+            qps=1000,
+            tps=1000,
         )

     def get_table_summary(self):

@@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary):
             self.indexes_info.append(index_summary.get_summery())

         self.json_summery = self.json_summery_template.format(
-            name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000
+            name=name,
+            comment=comment_map[name],
+            fields=self.fields_info,
+            indexes=self.indexes_info,
+            size_in_bytes=1000,
+            rows=1000,
         )

     def get_summery(self):

@@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary):
         self.bind_fields = index[1]

     def get_summery(self):
-        return self.summery_template.format(name=self.name, bind_fields=self.bind_fields)
+        return self.summery_template.format(
+            name=self.name, bind_fields=self.bind_fields
+        )


 if __name__ == "__main__":

View File

@@ -49,6 +49,7 @@ llama-index==0.5.27
 pymysql
 unstructured==0.6.3
 grpcio==1.47.5
+gpt4all==0.3.0
 auto-gpt-plugin-template
 pymdown-extensions