add llm support for gpt4all (#209)

add llm support for gpt4all #178 #138
This commit is contained in:
magic.chen 2023-06-14 11:35:42 +08:00 committed by GitHub
commit e3f91e98b4
29 changed files with 218 additions and 120 deletions

View File

@ -79,17 +79,21 @@ def load_native_plugins(cfg: Config):
branch_name = cfg.plugins_git_branch branch_name = cfg.plugins_git_branch
native_plugin_repo = "DB-GPT-Plugins" native_plugin_repo = "DB-GPT-Plugins"
url = "https://github.com/csunny/{repo}/archive/{branch}.zip" url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
response = requests.get(url.format(repo=native_plugin_repo, branch=branch_name), response = requests.get(
headers={'Authorization': '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'}) url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200: if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR) plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*')) files = glob.glob(os.path.join(plugins_path_path, f"{native_plugin_repo}*"))
for file in files: for file in files:
os.remove(file) os.remove(file)
now = datetime.datetime.now() now = datetime.datetime.now()
time_str = now.strftime('%Y%m%d%H%M%S') time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" file_name = (
f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
)
print(file_name) print(file_name)
with open(file_name, "wb") as f: with open(file_name, "wb") as f:
f.write(response.content) f.write(response.content)

View File

@ -66,7 +66,6 @@ class Database:
self._sample_rows_in_table_info = set() self._sample_rows_in_table_info = set()
self._indexes_in_table_info = indexes_in_table_info self._indexes_in_table_info = indexes_in_table_info
@classmethod @classmethod
def from_uri( def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
@ -399,7 +398,6 @@ class Database:
ans = cursor.fetchall() ans = cursor.fetchall()
return ans[0][1] return ans[0][1]
def get_fields(self, table_name): def get_fields(self, table_name):
"""Get column fields about specified table.""" """Get column fields about specified table."""
session = self._db_sessions() session = self._db_sessions()

View File

@ -43,6 +43,7 @@ LLM_MODEL_CONFIG = {
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"), "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"), "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
"gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
"proxyllm": "proxyllm", "proxyllm": "proxyllm",
} }

View File

@ -6,6 +6,7 @@ from pilot.configs.config import Config
CFG = Config() CFG = Config()
class ClickHouseConnector(RDBMSDatabase): class ClickHouseConnector(RDBMSDatabase):
"""ClickHouseConnector""" """ClickHouseConnector"""
@ -17,19 +18,21 @@ class ClickHouseConnector(RDBMSDatabase):
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod @classmethod
def from_config(cls) -> RDBMSDatabase: def from_config(cls) -> RDBMSDatabase:
""" """
Todo password encryption Todo password encryption
Returns: Returns:
""" """
return cls.from_uri_db(cls, return cls.from_uri_db(
cls,
CFG.LOCAL_DB_PATH, CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod @classmethod
def from_uri_db(cls, db_path: str, def from_uri_db(
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -6,6 +6,7 @@ from pilot.configs.config import Config
CFG = Config() CFG = Config()
class DuckDbConnect(RDBMSDatabase): class DuckDbConnect(RDBMSDatabase):
"""Connect Duckdb Database fetch MetaData """Connect Duckdb Database fetch MetaData
Args: Args:
@ -20,19 +21,21 @@ class DuckDbConnect(RDBMSDatabase):
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod @classmethod
def from_config(cls) -> RDBMSDatabase: def from_config(cls) -> RDBMSDatabase:
""" """
Todo password encryption Todo password encryption
Returns: Returns:
""" """
return cls.from_uri_db(cls, return cls.from_uri_db(
cls,
CFG.LOCAL_DB_PATH, CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod @classmethod
def from_uri_db(cls, db_path: str, def from_uri_db(
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path db_url: str = cls.connect_driver + "://" + db_path
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)

View File

@ -5,9 +5,6 @@ from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MSSQLConnect(RDBMSDatabase): class MSSQLConnect(RDBMSDatabase):
"""Connect MSSQL Database fetch MetaData """Connect MSSQL Database fetch MetaData
Args: Args:
@ -19,5 +16,3 @@ class MSSQLConnect(RDBMSDatabase):
driver: str = "pyodbc" driver: str = "pyodbc"
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"] default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]

View File

@ -5,9 +5,6 @@ from typing import Optional, Any
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class MySQLConnect(RDBMSDatabase): class MySQLConnect(RDBMSDatabase):
"""Connect MySQL Database fetch MetaData """Connect MySQL Database fetch MetaData
Args: Args:
@ -19,5 +16,3 @@ class MySQLConnect(RDBMSDatabase):
driver: str = "pymysql" driver: str = "pymysql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -2,8 +2,10 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector(RDBMSDatabase): class OracleConnector(RDBMSDatabase):
"""OracleConnector""" """OracleConnector"""
type: str = "ORACLE" type: str = "ORACLE"
driver: str = "oracle" driver: str = "oracle"

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase): class PostgresConnector(RDBMSDatabase):
"""PostgresConnector is a class which Connector""" """PostgresConnector is a class which Connector"""

View File

@ -57,17 +57,18 @@ CFG = Config()
if __name__ == "__main__": if __name__ == "__main__":
def __extract_json(s): def __extract_json(s):
i = s.index('{') i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1 :], start=i + 1): for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == '}': if c == "}":
count -= 1 count -= 1
elif c == '{': elif c == "{":
count += 1 count += 1
if count == 0: if count == 0:
break break
assert (count == 0) # 检查是否找到最后一个'}' assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1] return s[i : j + 1]
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""

View File

@ -41,7 +41,6 @@ class RDBMSDatabase(BaseConnect):
metadata: Optional[MetaData] = None, metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None, ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
): ):
"""Create engine from database URI.""" """Create engine from database URI."""
self._engine = engine self._engine = engine
@ -61,18 +60,37 @@ class RDBMSDatabase(BaseConnect):
Todo password encryption Todo password encryption
Returns: Returns:
""" """
return cls.from_uri_db(cls, return cls.from_uri_db(
cls,
CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT, CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD, CFG.LOCAL_DB_PASSWORD,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True}) engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod @classmethod
def from_uri_db(cls, host: str, port: int, user: str, pwd: str, db_name: str = None, def from_uri_db(
engine_args: Optional[dict] = None, **kwargs: Any) -> RDBMSDatabase: cls,
db_url: str = cls.connect_driver + "://" + CFG.LOCAL_DB_USER + ":" + CFG.LOCAL_DB_PASSWORD + "@" + CFG.LOCAL_DB_HOST + ":" + str( host: str,
CFG.LOCAL_DB_PORT) port: int,
user: str,
pwd: str,
db_name: str = None,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.connect_driver
+ "://"
+ CFG.LOCAL_DB_USER
+ ":"
+ CFG.LOCAL_DB_PASSWORD
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
)
if cls.dialect: if cls.dialect:
db_url = cls.dialect + "+" + db_url db_url = cls.dialect + "+" + db_url
if db_name: if db_name:

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import torch import torch
import os
from typing import List from typing import List
from functools import cache from functools import cache
from transformers import ( from transformers import (
@ -183,18 +184,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
class GPT4AllAdapter(BaseLLMAdaper): class GPT4AllAdapter(BaseLLMAdaper):
"""A light version for someone who want practise LLM use laptop.""" """
A light version for someone who want practise LLM use laptop.
All model names see: https://gpt4all.io/models/models.json
"""
def match(self, model_path: str): def match(self, model_path: str):
return "gpt4all" in model_path return "gpt4all" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict): def loader(self, model_path: str, from_pretrained_kwargs: dict):
# TODO import gpt4all
pass
if model_path is None and from_pretrained_kwargs.get("model_name") is None:
model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
else:
path, file = os.path.split(model_path)
model = gpt4all.GPT4All(model_path=path, model_name=file)
return model, None
class ProxyllmAdapter(BaseLLMAdaper): class ProxyllmAdapter(BaseLLMAdaper):
"""The model adapter for local proxy""" """The model adapter for local proxy"""
def match(self, model_path: str): def match(self, model_path: str):
@ -209,6 +218,7 @@ register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter) register_llm_model_adapters(GuanacoAdapter)
register_llm_model_adapters(FalconAdapater) register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter) register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
# TODO Default support vicuna, other model need to tests and Evaluate # TODO Default support vicuna, other model need to tests and Evaluate
# just for test_py, remove this later # just for test_py, remove this later

View File

@ -0,0 +1,23 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import threading
import sys
import time
def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
stop = params.get("stop", "###")
prompt = params["prompt"]
role, query = prompt.split(stop)[1].split(":")
print(f"gpt4all, role: {role}, query: {query}")
def worker():
model.generate(prompt=query, streaming=True)
t = threading.Thread(target=worker)
t.start()
while t.is_alive():
yield sys.stdout.output
time.sleep(0.01)
t.join()

View File

@ -76,7 +76,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
text = "" text = ""
for line in res.iter_lines(): for line in res.iter_lines():
if line: if line:
json_data = line.split(b': ', 1)[1] json_data = line.split(b": ", 1)[1]
decoded_line = json_data.decode("utf-8") decoded_line = json_data.decode("utf-8")
if decoded_line.lower() != "[DONE]".lower(): if decoded_line.lower() != "[DONE]".lower():
obj = json.loads(json_data) obj = json.loads(json_data)

View File

@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
""" """
if data["error_code"] == 0: if data.get("error_code", 0) == 0:
if "vicuna" in CFG.LLM_MODEL: if "vicuna" in CFG.LLM_MODEL:
# output = data["text"][skip_echo_len + 11:].strip() # output = data["text"][skip_echo_len + 11:].strip()
output = data["text"][skip_echo_len:].strip() output = data["text"][skip_echo_len:].strip()
@ -121,16 +121,16 @@ class BaseOutputParser(ABC):
raise ValueError("Model server error!code=" + respObj_ex["error_code"]) raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def __extract_json(slef, s): def __extract_json(slef, s):
i = s.index('{') i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1 :], start=i + 1): for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == '}': if c == "}":
count -= 1 count -= 1
elif c == '{': elif c == "{":
count += 1 count += 1
if count == 0: if count == 0:
break break
assert (count == 0) # 检查是否找到最后一个'}' assert count == 0 # 检查是否找到最后一个'}'
return s[i : j + 1] return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:

View File

@ -134,6 +134,7 @@ class BaseChat(ABC):
return payload return payload
def stream_call(self): def stream_call(self):
# TODO Retry when server connection error
payload = self.__call_base() payload = self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11 self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11

View File

@ -80,7 +80,6 @@ class ChatWithPlugin(BaseChat):
def __list_to_prompt_str(self, list: List) -> str: def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
def generate(self, p) -> str: def generate(self, p) -> str:
return super().generate(p) return super().generate(p)

View File

@ -31,7 +31,7 @@ class PluginChatOutputParser(BaseOutputParser):
command, thoughts, speak = ( command, thoughts, speak = (
response["command"], response["command"],
response["thoughts"], response["thoughts"],
response["speak"] response["speak"],
) )
return PluginAction(command, speak, thoughts) return PluginAction(command, speak, thoughts)

View File

@ -56,7 +56,9 @@ class ChatDefaultKnowledge(BaseChat):
context = context[:2000] context = context[:2000]
input_values = {"context": context, "question": self.current_user_input} input_values = {"context": context, "question": self.current_user_input}
except NoIndexException: except NoIndexException:
raise ValueError("you have no default knowledge store, please execute python knowledge_init.py") raise ValueError(
"you have no default knowledge store, please execute python knowledge_init.py"
)
return input_values return input_values
def do_with_prompt_response(self, prompt_response): def do_with_prompt_response(self, prompt_response):

View File

@ -5,7 +5,6 @@ import sys
from dotenv import load_dotenv from dotenv import load_dotenv
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"): if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
print("Setting random seed to 42") print("Setting random seed to 42")
random.seed(42) random.seed(42)

View File

@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
class VicunaChatAdapter(BaseChatAdpter): class VicunaChatAdapter(BaseChatAdpter):
"""Model chat Adapter for vicuna""" """Model chat Adapter for vicuna"""
def match(self, model_path: str): def match(self, model_path: str):
@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
class CodeT5ChatAdapter(BaseChatAdpter): class CodeT5ChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeT5""" """Model chat adapter for CodeT5"""
def match(self, model_path: str): def match(self, model_path: str):
@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):
class CodeGenChatAdapter(BaseChatAdpter): class CodeGenChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeGen""" """Model chat adapter for CodeGen"""
def match(self, model_path: str): def match(self, model_path: str):
@ -127,11 +124,22 @@ class GorillaChatAdapter(BaseChatAdpter):
return generate_stream return generate_stream
class GPT4AllChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "gpt4all" in model_path
def get_generate_stream_func(self):
from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
return gpt4all_generate_stream
register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(FalconChatAdapter) register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter) register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
# Proxy model for test and develop, it's cheap for us now. # Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter) register_llm_model_chat_adapter(ProxyllmChatAdapter)

View File

@ -39,9 +39,13 @@ class ModelWorker:
) )
if not isinstance(self.model, str): if not isinstance(self.model, str):
if hasattr(self.model.config, "max_sequence_length"): if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length"
):
self.context_len = self.model.config.max_sequence_length self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model.config, "max_position_embeddings"): elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings"
):
self.context_len = self.model.config.max_position_embeddings self.context_len = self.model.config.max_position_embeddings
else: else:
@ -69,7 +73,10 @@ class ModelWorker:
for output in self.generate_stream_func( for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
): ):
print("output: ", output) # Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
# and opening it may affect the frontend output.
# print("output: ", output)
ret = { ret = {
"text": output, "text": output,
"error_code": 0, "error_code": 0,
@ -79,6 +86,12 @@ class ModelWorker:
except torch.cuda.CudaError: except torch.cuda.CudaError:
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0} ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
yield json.dumps(ret).encode() + b"\0" yield json.dumps(ret).encode() + b"\0"
except Exception as e:
ret = {
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
def get_embeddings(self, prompt): def get_embeddings(self, prompt):
return get_embeddings(self.model, self.tokenizer, prompt) return get_embeddings(self.model, self.tokenizer, prompt)

View File

@ -30,7 +30,11 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from markdown path.""" """Load from markdown path."""
loader = EncodeTextLoader(self.file_path) loader = EncodeTextLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200) textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)
@register @register

View File

@ -29,7 +29,9 @@ class PDFEmbedding(SourceEmbedding):
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# ) # )
textsplitter = SpacyTextSplitter( textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200 pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
) )
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)

View File

@ -25,7 +25,11 @@ class PPTEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from ppt path.""" """Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path) loader = UnstructuredPowerPointLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200) textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)
@register @register

View File

@ -147,8 +147,6 @@ class DBSummaryClient:
logger.info("init db profile success...") logger.info("init db profile success...")
def _get_llm_response(query, db_input, dbsummary): def _get_llm_response(query, db_input, dbsummary):
chat_param = { chat_param = {
"temperature": 0.7, "temperature": 0.7,

View File

@ -43,15 +43,14 @@ CFG = Config()
# "tps": 50 # "tps": 50
# } # }
class MysqlSummary(DBSummary): class MysqlSummary(DBSummary):
"""Get mysql summary template.""" """Get mysql summary template."""
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.type = "MYSQL" self.type = "MYSQL"
self.summery = ( self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
"""{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
)
self.tables = {} self.tables = {}
self.tables_info = [] self.tables_info = []
self.vector_tables_info = [] self.vector_tables_info = []
@ -92,8 +91,11 @@ class MysqlSummary(DBSummary):
self.tables[table_name] = table_summary.get_columns() self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns()) self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json()) # self.table_columns_json.append(table_summary.get_summary_json())
table_profile = "table name:{table_name},table description:{table_comment}".format( table_profile = (
table_name=table_name, table_comment=self.db.get_show_create_table(table_name) "table name:{table_name},table description:{table_comment}".format(
table_name=table_name,
table_comment=self.db.get_show_create_table(table_name),
)
) )
self.table_columns_json.append(table_profile) self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summery()) # self.tables_info.append(table_summary.get_summery())
@ -108,7 +110,11 @@ class MysqlSummary(DBSummary):
def get_db_summery(self): def get_db_summery(self):
return self.summery.format( return self.summery.format(
name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000 name=self.name,
type=self.type,
tables=";".join(self.vector_tables_info),
qps=1000,
tps=1000,
) )
def get_table_summary(self): def get_table_summary(self):
@ -153,7 +159,12 @@ class MysqlTableSummary(TableSummary):
self.indexes_info.append(index_summary.get_summery()) self.indexes_info.append(index_summary.get_summery())
self.json_summery = self.json_summery_template.format( self.json_summery = self.json_summery_template.format(
name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000 name=name,
comment=comment_map[name],
fields=self.fields_info,
indexes=self.indexes_info,
size_in_bytes=1000,
rows=1000,
) )
def get_summery(self): def get_summery(self):
@ -203,7 +214,9 @@ class MysqlIndexSummary(IndexSummary):
self.bind_fields = index[1] self.bind_fields = index[1]
def get_summery(self): def get_summery(self):
return self.summery_template.format(name=self.name, bind_fields=self.bind_fields) return self.summery_template.format(
name=self.name, bind_fields=self.bind_fields
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -49,6 +49,7 @@ llama-index==0.5.27
pymysql pymysql
unstructured==0.6.3 unstructured==0.6.3
grpcio==1.47.5 grpcio==1.47.5
gpt4all==0.3.0
auto-gpt-plugin-template auto-gpt-plugin-template
pymdown-extensions pymdown-extensions