diff --git a/README.md b/README.md
index 88b488aea..06d9e8022 100644
--- a/README.md
+++ b/README.md
@@ -188,7 +188,7 @@ The core capabilities mainly consist of the following parts:


-[**Quickstart**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html)
+[**Installation & Usage Tutorial**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html)
### Language Switching
In the .env configuration file, modify the LANGUAGE parameter to switch to different languages. The default is English (Chinese: zh, English: en, other languages to be added later).
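+
+For example, to switch the interface to Chinese, set the following in `.env`:
+
+```
+LANGUAGE=zh
+```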
@@ -214,8 +214,17 @@ The core capabilities mainly consist of the following parts:
- [ ] Images
- [x] RAG
-- [ ] KnownledgeGraph
-
+- [ ] Graph Database
+ - [ ] Neo4j Graph
+ - [ ] Nebula Graph
+- [x] Multi Vector Database
+ - [x] Chroma
+ - [x] Milvus
+ - [x] Weaviate
+ - [x] PGVector
+ - [ ] Elasticsearch
+ - [ ] ClickHouse
+ - [ ] Faiss
### Multi Datasource Support
- Multi Datasource Support
@@ -239,9 +248,9 @@ The core capabilities mainly consist of the following parts:
- [ ] StarRocks
### Multi-Models And vLLM
-- [x] [cluster deployment](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)
-- [x] [fastchat support](https://github.com/lm-sys/FastChat)
-- [x] [vLLM support](https://db-gpt.readthedocs.io/en/latest/getting_started/install/llm/vllm/vllm.html)
+- [x] [Cluster Deployment](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html)
+- [x] [Fastchat Support](https://github.com/lm-sys/FastChat)
+- [x] [vLLM Support](https://db-gpt.readthedocs.io/en/latest/getting_started/install/llm/vllm/vllm.html)
### Agents market and Plugins
- [x] multi-agents framework
diff --git a/README.zh.md b/README.zh.md
index 3cfaa6ed5..f1a73f063 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -258,7 +258,17 @@ The MIT License (MIT)
- [ ] Code
- [ ] Images
- [x] RAG
-- [ ] KnownledgeGraph
+- [ ] Graph Database
+ - [ ] Neo4j Graph
+ - [ ] Nebula Graph
+- [x] Multi Vector Database
+ - [x] Chroma
+ - [x] Milvus
+ - [x] Weaviate
+ - [x] PGVector
+ - [ ] Elasticsearch
+ - [ ] ClickHouse
+ - [ ] Faiss
### 多数据源支持
diff --git a/examples/proxy_example.py b/examples/proxy_example.py
index 5d2f8e5db..a3d2f3bc4 100644
--- a/examples/proxy_example.py
+++ b/examples/proxy_example.py
@@ -7,55 +7,61 @@ import hashlib
from http import HTTPStatus
from dashscope import Generation
+
def call_with_messages():
- messages = [{'role': 'system', 'content': '你是生活助手机器人。'},
- {'role': 'user', 'content': '如何做西红柿鸡蛋?'}]
+ messages = [
+ {"role": "system", "content": "你是生活助手机器人。"},
+ {"role": "user", "content": "如何做西红柿鸡蛋?"},
+ ]
gen = Generation()
response = gen.call(
Generation.Models.qwen_turbo,
messages=messages,
stream=True,
top_p=0.8,
- result_format='message', # set the result to be "message" format.
+        result_format="message",  # return results in "message" format.
)
for response in response:
# The response status_code is HTTPStatus.OK indicate success,
# otherwise indicate request is failed, you can get error code
# and message from code and message.
- if response.status_code == HTTPStatus.OK:
- print(response.output) # The output text
+ if response.status_code == HTTPStatus.OK:
+ print(response.output) # The output text
print(response.usage) # The usage information
else:
- print(response.code) # The error code.
- print(response.message) # The error message.
-
+ print(response.code) # The error code.
+ print(response.message) # The error message.
def build_access_token(api_key: str, secret_key: str) -> str:
"""
- Generate Access token according AK, SK
+    Generate an access token from the AK (API key) and SK (secret key).
"""
-
+
url = "https://aip.baidubce.com/oauth/2.0/token"
- params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
+ params = {
+ "grant_type": "client_credentials",
+ "client_id": api_key,
+ "client_secret": secret_key,
+ }
res = requests.get(url=url, params=params)
-
+
if res.status_code == 200:
return res.json().get("access_token")
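+    # Fail loudly on non-200 responses instead of silently returning None
+    # (an assumption about caller preference, not original behavior).
+    res.raise_for_status()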
def _calculate_md5(text: str) -> str:
-
md5 = hashlib.md5()
md5.update(text.encode("utf-8"))
encrypted = md5.hexdigest()
return encrypted
+
def baichuan_call():
url = "https://api.baichuan-ai.com/v1/stream/chat"
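+    # A minimal request sketch (assumptions: the env var names below and the
+    # MD5 signing scheme mirror pilot/model/proxy/llms/baichuan.py; `os`,
+    # `time`, `json` and `requests` are imported at the top of this file).
+    api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
+    secret_key = os.getenv("BAICHUAN_PROXY_API_SECRET")
+    data = {
+        "model": "Baichuan2-53B",
+        "messages": [{"role": "user", "content": "你好"}],
+    }
+    timestamp = str(int(time.time()))
+    data_str = json.dumps(data)
+    # Sign the exact serialized body: md5(secret_key + body + timestamp)
+    signature = _calculate_md5(secret_key + data_str + timestamp)
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": "Bearer " + api_key,
+        "X-BC-Timestamp": timestamp,
+        "X-BC-Signature": signature,
+        "X-BC-Sign-Algo": "MD5",
+    }
+    res = requests.post(url, data=data_str, headers=headers, stream=True)
+    for line in res.iter_lines():
+        if line:
+            print(line.decode("utf-8"))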
-
-
-if __name__ == '__main__':
- call_with_messages()
\ No newline at end of file
+
+
+if __name__ == "__main__":
+ call_with_messages()
diff --git a/pilot/base_modules/agent/__init__.py b/pilot/base_modules/agent/__init__.py
index e7bbb7b7b..60f0489da 100644
--- a/pilot/base_modules/agent/__init__.py
+++ b/pilot/base_modules/agent/__init__.py
@@ -1,4 +1,4 @@
-from .db.my_plugin_db import MyPluginEntity, MyPluginDao
+from .db.my_plugin_db import MyPluginEntity, MyPluginDao
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
from .commands.command import execute_command, get_command
@@ -8,4 +8,4 @@ from .commands.disply_type.show_chart_gen import static_message_img_path
from .common.schema import Status, PluginStorageType
from .commands.command_mange import ApiCall
-from .commands.command import execute_command
\ No newline at end of file
+from .commands.command import execute_command
diff --git a/pilot/base_modules/agent/commands/command.py b/pilot/base_modules/agent/commands/command.py
index a3202f67d..bd5806ec0 100644
--- a/pilot/base_modules/agent/commands/command.py
+++ b/pilot/base_modules/agent/commands/command.py
@@ -9,6 +9,7 @@ from .generator import PluginPromptGenerator
from pilot.configs.config import Config
+
def _resolve_pathlike_command_args(command_args):
if "directory" in command_args and command_args["directory"] in {"", "/"}:
# todo
@@ -64,8 +65,6 @@ def execute_ai_response_json(
return result
-
-
def execute_command(
command_name: str,
arguments,
@@ -81,10 +80,8 @@ def execute_command(
str: The result of the command
"""
-
cmd = plugin_generator.command_registry.commands.get(command_name)
-
# If the command is found, call it with the provided arguments
if cmd:
try:
@@ -153,6 +150,3 @@ def get_command(response_json: Dict):
# All other errors, return "Error: + error message"
except Exception as e:
return "Error:", str(e)
-
-
-
diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py
index 7b8f8cc06..be9e02811 100644
--- a/pilot/base_modules/agent/commands/command_mange.py
+++ b/pilot/base_modules/agent/commands/command_mange.py
@@ -28,13 +28,13 @@ class Command:
"""
def __init__(
- self,
- name: str,
- description: str,
- method: Callable[..., Any],
- signature: str = "",
- enabled: bool = True,
- disabled_reason: Optional[str] = None,
+ self,
+ name: str,
+ description: str,
+ method: Callable[..., Any],
+ signature: str = "",
+ enabled: bool = True,
+ disabled_reason: Optional[str] = None,
):
self.name = name
self.description = description
@@ -87,11 +87,12 @@ class CommandRegistry:
if hasattr(reloaded_module, "register"):
reloaded_module.register(self)
- def is_valid_command(self, name:str)-> bool:
+ def is_valid_command(self, name: str) -> bool:
if name not in self.commands:
return False
else:
return True
+
def get_command(self, name: str) -> Callable[..., Any]:
return self.commands[name]
@@ -129,23 +130,23 @@ class CommandRegistry:
attr = getattr(module, attr_name)
# Register decorated functions
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
- attr, AUTO_GPT_COMMAND_IDENTIFIER
+ attr, AUTO_GPT_COMMAND_IDENTIFIER
):
self.register(attr.command)
# Register command classes
elif (
- inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
+ inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
):
cmd_instance = attr()
self.register(cmd_instance)
def command(
- name: str,
- description: str,
- signature: str = "",
- enabled: bool = True,
- disabled_reason: Optional[str] = None,
+ name: str,
+ description: str,
+ signature: str = "",
+ enabled: bool = True,
+ disabled_reason: Optional[str] = None,
) -> Callable[..., Any]:
"""The command decorator is used to create Command objects from ordinary functions."""
@@ -241,34 +242,60 @@ class ApiCall:
else:
md_tag_end = "```"
for tag in error_md_tags:
- all_context = all_context.replace(tag + api_context + md_tag_end, api_context)
- all_context = all_context.replace(tag + "\n" +api_context + "\n" + md_tag_end, api_context)
- all_context = all_context.replace(tag + " " +api_context + " " + md_tag_end, api_context)
+ all_context = all_context.replace(
+ tag + api_context + md_tag_end, api_context
+ )
+ all_context = all_context.replace(
+ tag + "\n" + api_context + "\n" + md_tag_end, api_context
+ )
+ all_context = all_context.replace(
+ tag + " " + api_context + " " + md_tag_end, api_context
+ )
all_context = all_context.replace(tag + api_context, api_context)
return all_context
def api_view_context(self, all_context: str, display_mode: bool = False):
error_mk_tags = ["```", "```python", "```xml"]
- call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
+ call_context_map = extract_content_open_ending(
+ all_context, self.agent_prefix, self.agent_end, True
+ )
for api_index, api_context in call_context_map.items():
api_status = self.plugin_status_map.get(api_context)
if api_status is not None:
if display_mode:
if api_status.api_result:
- all_context = self.__deal_error_md_tags(all_context, api_context)
- all_context = all_context.replace(api_context, api_status.api_result)
+ all_context = self.__deal_error_md_tags(
+ all_context, api_context
+ )
+ all_context = all_context.replace(
+ api_context, api_status.api_result
+ )
else:
if api_status.status == Status.FAILED.value:
- all_context = self.__deal_error_md_tags(all_context, api_context)
- all_context = all_context.replace(api_context, f"""\nERROR!{api_status.err_msg}\n """)
+ all_context = self.__deal_error_md_tags(
+ all_context, api_context
+ )
+ all_context = all_context.replace(
+ api_context,
+ f"""\nERROR!{api_status.err_msg}\n """,
+ )
else:
cost = (api_status.end_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
- all_context = self.__deal_error_md_tags(all_context, api_context)
- all_context = all_context.replace(api_context, f'\nWaiting...{cost_str}S\n')
+ all_context = self.__deal_error_md_tags(
+ all_context, api_context
+ )
+ all_context = all_context.replace(
+ api_context,
+                    f"\nWaiting...{cost_str}S\n",
+ )
else:
- all_context = self.__deal_error_md_tags(all_context, api_context, False)
- all_context = all_context.replace(api_context, self.to_view_text(api_status))
+ all_context = self.__deal_error_md_tags(
+ all_context, api_context, False
+ )
+ all_context = all_context.replace(
+ api_context, self.to_view_text(api_status)
+ )
else:
# not ready api call view change
@@ -276,27 +303,34 @@ class ApiCall:
cost = (now_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
for tag in error_mk_tags:
- all_context = all_context.replace(tag + api_context , api_context)
- all_context = all_context.replace(api_context, f'\nWaiting...{cost_str}S\n')
+ all_context = all_context.replace(tag + api_context, api_context)
+ all_context = all_context.replace(
+ api_context,
+                f"\nWaiting...{cost_str}S\n",
+ )
return all_context
def update_from_context(self, all_context):
- api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
+ api_context_map = extract_content(
+ all_context, self.agent_prefix, self.agent_end, True
+ )
for api_index, api_context in api_context_map.items():
api_context = api_context.replace("\\n", "").replace("\n", "")
api_call_element = ET.fromstring(api_context)
- api_name = api_call_element.find('name').text
- if api_name.find("[")>=0 or api_name.find("]")>=0:
+ api_name = api_call_element.find("name").text
+ if api_name.find("[") >= 0 or api_name.find("]") >= 0:
api_name = api_name.replace("[", "").replace("]", "")
api_args = {}
- args_elements = api_call_element.find('args')
+ args_elements = api_call_element.find("args")
for child_element in args_elements.iter():
api_args[child_element.tag] = child_element.text
api_status = self.plugin_status_map.get(api_context)
if api_status is None:
- api_status = PluginStatus(name=api_name, location=[api_index], args=api_args)
+ api_status = PluginStatus(
+ name=api_name, location=[api_index], args=api_args
+ )
self.plugin_status_map[api_context] = api_status
else:
api_status.location.append(api_index)
@@ -304,20 +338,20 @@ class ApiCall:
def __to_view_param_str(self, api_status):
param = {}
if api_status.name:
- param['name'] = api_status.name
- param['status'] = api_status.status
+ param["name"] = api_status.name
+ param["status"] = api_status.status
if api_status.logo_url:
- param['logo'] = api_status.logo_url
+ param["logo"] = api_status.logo_url
if api_status.err_msg:
- param['err_msg'] = api_status.err_msg
+ param["err_msg"] = api_status.err_msg
if api_status.api_result:
- param['result'] = api_status.api_result
+ param["result"] = api_status.api_result
return json.dumps(param)
def to_view_text(self, api_status: PluginStatus):
- api_call_element = ET.Element('dbgpt-view')
+ api_call_element = ET.Element("dbgpt-view")
api_call_element.text = self.__to_view_param_str(api_status)
result = ET.tostring(api_call_element, encoding="utf-8")
return result.decode("utf-8")
@@ -332,7 +366,9 @@ class ApiCall:
value.status = Status.RUNNING.value
logging.info(f"插件执行:{value.name},{value.args}")
try:
- value.api_result = execute_command(value.name, value.args, self.plugin_generator)
+ value.api_result = execute_command(
+ value.name, value.args, self.plugin_generator
+ )
value.status = Status.COMPLETED.value
except Exception as e:
value.status = Status.FAILED.value
@@ -350,15 +386,19 @@ class ApiCall:
value.status = Status.RUNNING.value
logging.info(f"sql展示执行:{value.name},{value.args}")
try:
- sql = value.args['sql']
+ sql = value.args["sql"]
if sql:
param = {
"df": sql_run_func(sql),
}
if self.display_registry.is_valid_command(value.name):
- value.api_result = self.display_registry.call(value.name, **param)
+ value.api_result = self.display_registry.call(
+ value.name, **param
+ )
else:
- value.api_result = self.display_registry.call("response_table", **param)
+ value.api_result = self.display_registry.call(
+ "response_table", **param
+ )
value.status = Status.COMPLETED.value
except Exception as e:
@@ -366,4 +406,3 @@ class ApiCall:
value.err_msg = str(e)
value.end_time = datetime.now().timestamp() * 1000
return self.api_view_context(llm_text, True)
-
diff --git a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
index 7807f3818..166992822 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
@@ -15,6 +15,7 @@ from matplotlib.font_manager import FontManager
from pilot.common.string_utils import is_scientific_notation
import logging
+
logger = logging.getLogger(__name__)
@@ -88,12 +89,13 @@ def zh_font_set():
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
+
def format_axis(value, pos):
# 判断是否为数字
if is_scientific_notation(value):
# 判断是否需要进行非科学计数法格式化
- return '{:.2f}'.format(value)
+ return "{:.2f}".format(value)
return value
@@ -102,7 +104,7 @@ def format_axis(value, pos):
"Line chart display, used to display comparative trend analysis data",
'"df":""',
)
-def response_line_chart( df: DataFrame) -> str:
+def response_line_chart(df: DataFrame) -> str:
logger.info(f"response_line_chart")
if df.size <= 0:
raise ValueError("No Data!")
@@ -143,9 +145,15 @@ def response_line_chart( df: DataFrame) -> str:
if len(num_colmns) > 0:
num_colmns.append(y)
df_melted = pd.melt(
- df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
+ df,
+ id_vars=x,
+ value_vars=num_colmns,
+ var_name="line",
+ value_name="Value",
+ )
+ sns.lineplot(
+ data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2"
)
- sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
@@ -154,7 +162,7 @@ def response_line_chart( df: DataFrame) -> str:
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
- plt.savefig(chart_path, dpi=100, transparent=True)
+ plt.savefig(chart_path, dpi=100, transparent=True)
html_img = f"""
"""
return html_img
@@ -168,7 +176,7 @@ def response_line_chart( df: DataFrame) -> str:
"Histogram, suitable for comparative analysis of multiple target values",
'"df":""',
)
-def response_bar_chart( df: DataFrame) -> str:
+def response_bar_chart(df: DataFrame) -> str:
logger.info(f"response_bar_chart")
if df.size <= 0:
raise ValueError("No Data!")
@@ -246,7 +254,7 @@ def response_bar_chart( df: DataFrame) -> str:
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
- plt.savefig(chart_path, dpi=100,transparent=True)
+ plt.savefig(chart_path, dpi=100, transparent=True)
html_img = f"""
"""
return html_img
diff --git a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
index 9afd14ca5..d11a00c7a 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
@@ -3,8 +3,10 @@ from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
import logging
+
logger = logging.getLogger(__name__)
+
@command(
"response_table",
"Table display, suitable for display with many display columns or non-numeric columns",
diff --git a/pilot/base_modules/agent/commands/disply_type/show_text_gen.py b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py
index 75400fd97..b58ca5843 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_text_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py
@@ -3,6 +3,7 @@ from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
import logging
+
logger = logging.getLogger(__name__)
@@ -22,7 +23,7 @@ def response_data_text(df: DataFrame) -> str:
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""{table_str}
"""
- text_info = html.replace("\n", " ")
+ text_info = html.replace("\n", " ")
elif row_size == 1:
row = data[0]
for value in row:
diff --git a/pilot/base_modules/agent/commands/generator.py b/pilot/base_modules/agent/commands/generator.py
index 310551481..d1dded6fe 100644
--- a/pilot/base_modules/agent/commands/generator.py
+++ b/pilot/base_modules/agent/commands/generator.py
@@ -133,5 +133,3 @@ class PluginPromptGenerator:
def generate_commands_string(self) -> str:
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
-
-
diff --git a/pilot/base_modules/agent/common/schema.py b/pilot/base_modules/agent/common/schema.py
index 5c36e4da6..87ba196b0 100644
--- a/pilot/base_modules/agent/common/schema.py
+++ b/pilot/base_modules/agent/common/schema.py
@@ -5,13 +5,14 @@ class PluginStorageType(Enum):
Git = "git"
Oss = "oss"
+
class Status(Enum):
TODO = "todo"
- RUNNING = 'running'
- FAILED = 'failed'
- COMPLETED = 'completed'
+ RUNNING = "running"
+ FAILED = "failed"
+ COMPLETED = "completed"
class ApiTagType(Enum):
API_VIEW = "dbgpt_view"
- API_CALL = "dbgpt_call"
\ No newline at end of file
+ API_CALL = "dbgpt_call"
diff --git a/pilot/base_modules/agent/controller.py b/pilot/base_modules/agent/controller.py
index 4bca3dee8..47532a66b 100644
--- a/pilot/base_modules/agent/controller.py
+++ b/pilot/base_modules/agent/controller.py
@@ -16,7 +16,13 @@ from pilot.openapi.api_view_model import (
Result,
)
-from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter
+from .model import (
+ PluginHubParam,
+ PagenationFilter,
+ PagenationResult,
+ PluginHubFilter,
+ MyPluginFilter,
+)
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .plugins_util import scan_plugins
@@ -33,17 +39,18 @@ class ModuleAgent(BaseComponent, ABC):
name = ComponentType.AGENT_HUB
def __init__(self):
- #load plugins
+ # load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
-
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
- def load_select_plugin(self, generator:PluginPromptGenerator, select_plugins:List[str])->PluginPromptGenerator:
+ def load_select_plugin(
+ self, generator: PluginPromptGenerator, select_plugins: List[str]
+ ) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
@@ -53,6 +60,7 @@ class ModuleAgent(BaseComponent, ABC):
generator = plugin.post_prompt(generator)
return generator
+
module_agent = ModuleAgent()
@@ -61,25 +69,28 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
- agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
+ agent_hub.refresh_hub_from_git(
+ update_param.url, update_param.branch, update_param.authorization
+ )
return Result.succ(None)
except Exception as e:
logger.error("Agent Hub Update Error!", e)
return Result.faild(code="E0020", msg=f"Agent Hub Update Error! {e}")
-
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
- filter_enetity:PluginHubEntity = PluginHubEntity()
+ filter_enetity: PluginHubEntity = PluginHubEntity()
if filter.filter:
attrs = vars(filter.filter) # 获取原始对象的属性字典
for attr, value in attrs.items():
setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
- datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size)
+ datas, total_pages, total_count = agent_hub.hub_dao.list(
+ filter_enetity, filter.page_index, filter.page_size
+ )
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
@@ -89,11 +100,12 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
+
@router.post("/v1/agent/my", response_model=Result[str])
-async def my_agents(user:str= None):
+async def my_agents(user: str = None):
logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
- agents = agent_hub.get_my_plugin(user)
+ agents = agent_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
@@ -102,7 +114,7 @@ async def my_agents(user:str= None):
@router.post("/v1/agent/install", response_model=Result[str])
-async def agent_install(plugin_name:str, user: str = None):
+async def agent_install(plugin_name: str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
@@ -111,14 +123,13 @@ async def agent_install(plugin_name:str, user: str = None):
module_agent.refresh_plugins()
return Result.succ(None)
- except Exception as e:
+ except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
-
@router.post("/v1/agent/uninstall", response_model=Result[str])
-async def agent_uninstall(plugin_name:str, user: str = None):
+async def agent_uninstall(plugin_name: str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
@@ -126,19 +137,18 @@ async def agent_uninstall(plugin_name:str, user: str = None):
module_agent.refresh_plugins()
return Result.succ(None)
- except Exception as e:
+ except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
-async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
+async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
await agent_hub.upload_my_plugin(doc_file, user)
return Result.succ(None)
- except Exception as e:
+ except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")
-
diff --git a/pilot/base_modules/agent/db/plugin_hub_db.py b/pilot/base_modules/agent/db/plugin_hub_db.py
index f5620fc79..8507bcc8e 100644
--- a/pilot/base_modules/agent/db/plugin_hub_db.py
+++ b/pilot/base_modules/agent/db/plugin_hub_db.py
@@ -10,8 +10,10 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session
class PluginHubEntity(Base):
- __tablename__ = 'plugin_hub'
- id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
+ __tablename__ = "plugin_hub"
+ id = Column(
+ Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
+ )
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
description = Column(String(255), nullable=False, comment="plugin description")
author = Column(String(255), nullable=True, comment="plugin author")
@@ -25,8 +27,8 @@ class PluginHubEntity(Base):
installed = Column(Integer, default=False, comment="plugin already installed count")
__table_args__ = (
- UniqueConstraint('name', name="uk_name"),
- Index('idx_q_type', 'type'),
+ UniqueConstraint("name", name="uk_name"),
+ Index("idx_q_type", "type"),
)
@@ -38,7 +40,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def add(self, engity: PluginHubEntity):
session = self.get_session()
- timezone = pytz.timezone('Asia/Shanghai')
+ timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=engity.name,
author=engity.author,
@@ -64,7 +66,9 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
finally:
session.close()
- def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
+ def list(
+ self, query: PluginHubEntity, page=1, page_size=20
+ ) -> list[PluginHubEntity]:
session = self.get_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
@@ -72,17 +76,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.name == query.name
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.type == query.type
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.author == query.author
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
@@ -110,9 +108,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_session()
plugin_hubs = session.query(PluginHubEntity)
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.name == name
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
session.close()
return result
@@ -123,17 +119,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.name == query.name
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.type == query.type
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.author == query.author
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
@@ -148,9 +138,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)
if plugin_id is not None:
- plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.id == plugin_id
- )
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
plugin_hubs.delete()
session.commit()
session.close()
diff --git a/pilot/base_modules/agent/model.py b/pilot/base_modules/agent/model.py
index 8c02e4af4..eed2cb420 100644
--- a/pilot/base_modules/agent/model.py
+++ b/pilot/base_modules/agent/model.py
@@ -3,32 +3,35 @@ from dataclasses import dataclass
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
-T = TypeVar('T')
+T = TypeVar("T")
+
class PagenationFilter(BaseModel, Generic[T]):
page_index: int = 1
- page_size: int = 20
+ page_size: int = 20
filter: T = None
+
class PagenationResult(BaseModel, Generic[T]):
page_index: int = 1
- page_size: int = 20
+ page_size: int = 20
total_page: int = 0
total_row_count: int = 0
datas: List[T] = []
def to_dic(self):
- data_dicts =[]
+ data_dicts = []
for item in self.datas:
data_dicts.append(item.__dict__)
return {
- 'page_index': self.page_index,
- 'page_size': self.page_size,
- 'total_page': self.total_page,
- 'total_row_count': self.total_row_count,
- 'datas': data_dicts
+ "page_index": self.page_index,
+ "page_size": self.page_size,
+ "total_page": self.total_page,
+ "total_row_count": self.total_row_count,
+ "datas": data_dicts,
}
+
@dataclass
class PluginHubFilter(BaseModel):
name: str
@@ -53,9 +56,14 @@ class MyPluginFilter(BaseModel):
class PluginHubParam(BaseModel):
- channel: Optional[str] = Field("git", description="Plugin storage channel")
- url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url")
- branch: Optional[str] = Field("main", description="github download branch", nullable=True)
- authorization: Optional[str] = Field(None, description="github download authorization", nullable=True)
-
-
+ channel: Optional[str] = Field("git", description="Plugin storage channel")
+ url: Optional[str] = Field(
+ "https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
+ description="Plugin storage url",
+ )
+ branch: Optional[str] = Field(
+ "main", description="github download branch", nullable=True
+ )
+ authorization: Optional[str] = Field(
+ None, description="github download authorization", nullable=True
+ )
diff --git a/pilot/base_modules/agent/plugins_util.py b/pilot/base_modules/agent/plugins_util.py
index b0a9c7885..facc1472d 100644
--- a/pilot/base_modules/agent/plugins_util.py
+++ b/pilot/base_modules/agent/plugins_util.py
@@ -117,7 +117,7 @@ def load_native_plugins(cfg: Config):
t.start()
-def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
+def __scan_plugin_file(file_path, debug: bool = False) -> List[AutoGPTPluginTemplate]:
logger.info(f"__scan_plugin_file:{file_path},{debug}")
loaded_plugins = []
if moduleList := inspect_zip_for_modules(str(file_path), debug):
@@ -133,14 +133,17 @@ def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTempl
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if (
- "_abc_impl" in a_keys
- and a_module.__name__ != "AutoGPTPluginTemplate"
- # and denylist_allowlist_check(a_module.__name__, cfg)
+ "_abc_impl" in a_keys
+ and a_module.__name__ != "AutoGPTPluginTemplate"
+ # and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
return loaded_plugins
-def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
+
+def scan_plugins(
+ plugins_file_path: str, file_name: str = "", debug: bool = False
+) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
Args:
@@ -159,7 +162,7 @@ def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = Fals
loaded_plugins = __scan_plugin_file(plugin_path)
else:
for plugin_path in plugins_path.glob("*.zip"):
- loaded_plugins.extend(__scan_plugin_file(plugin_path))
+ loaded_plugins.extend(__scan_plugin_file(plugin_path))
if loaded_plugins:
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
@@ -192,17 +195,23 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
return ack.lower() == cfg.authorise_key
-def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
- authorization: str = None):
+def update_from_git(
+ download_path: str,
+ github_repo: str = "",
+ branch_name: str = "main",
+ authorization: str = None,
+):
os.makedirs(download_path, exist_ok=True)
if github_repo:
-        if github_repo.index("github.com") <= 0:
+        # use find(): index() raises ValueError when "github.com" is absent
+        if github_repo.find("github.com") <= 0:
raise ValueError("Not a correct Github repository address!" + github_repo)
github_repo = github_repo.replace(".git", "")
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
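+        # e.g. https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip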
- plugin_repo_name = github_repo.strip('/').split('/')[-1]
+ plugin_repo_name = github_repo.strip("/").split("/")[-1]
else:
- url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
+ url = (
+ "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
+ )
plugin_repo_name = "DB-GPT-Plugins"
try:
session = requests.Session()
@@ -216,14 +225,14 @@ def update_from_git(download_path: str, github_repo: str = "", branch_name: str
if response.status_code == 200:
plugins_path_path = Path(download_path)
- files = glob.glob(
- os.path.join(plugins_path_path, f"{plugin_repo_name}*")
- )
+ files = glob.glob(os.path.join(plugins_path_path, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
- file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
+ file_name = (
+ f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
+ )
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)
diff --git a/pilot/base_modules/mange_base_api.py b/pilot/base_modules/mange_base_api.py
index c0b5da273..57f2a27e7 100644
--- a/pilot/base_modules/mange_base_api.py
+++ b/pilot/base_modules/mange_base_api.py
@@ -1,7 +1,6 @@
class ModuleMangeApi:
-
def module_name(self):
pass
def register(self):
- pass
\ No newline at end of file
+ pass
diff --git a/pilot/base_modules/meta_data/base_dao.py b/pilot/base_modules/meta_data/base_dao.py
index 330bed592..693fe5699 100644
--- a/pilot/base_modules/meta_data/base_dao.py
+++ b/pilot/base_modules/meta_data/base_dao.py
@@ -1,11 +1,16 @@
from typing import TypeVar, Generic, List, Any
from sqlalchemy.orm import sessionmaker
-T = TypeVar('T')
+T = TypeVar("T")
+
class BaseDao(Generic[T]):
def __init__(
- self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None,
+ self,
+ orm_base=None,
+ database: str = None,
+ db_engine: Any = None,
+ session: Any = None,
) -> None:
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
self._orm_base = orm_base
diff --git a/pilot/base_modules/meta_data/meta_data.py b/pilot/base_modules/meta_data/meta_data.py
index 46a5f44b0..e5551da82 100644
--- a/pilot/base_modules/meta_data/meta_data.py
+++ b/pilot/base_modules/meta_data/meta_data.py
@@ -7,7 +7,7 @@ import fnmatch
from datetime import datetime
from typing import Optional, Type, TypeVar
-from sqlalchemy import create_engine,DateTime, String, func, MetaData
+from sqlalchemy import create_engine, DateTime, String, func, MetaData
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
@@ -32,16 +32,17 @@ db_name = "dbgpt"
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
-if CFG.LOCAL_DB_TYPE == 'mysql':
- engine_temp = create_engine(f"mysql+pymysql://"
- + quote(CFG.LOCAL_DB_USER)
- + ":"
- + quote(CFG.LOCAL_DB_PASSWORD)
- + "@"
- + CFG.LOCAL_DB_HOST
- + ":"
- + str(CFG.LOCAL_DB_PORT)
- )
+if CFG.LOCAL_DB_TYPE == "mysql":
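+    # Server-level URL without a database suffix (e.g.
+    # mysql+pymysql://user:***@127.0.0.1:3306), so the dbgpt database
+    # can be created below when it does not exist yet.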
+ engine_temp = create_engine(
+ f"mysql+pymysql://"
+ + quote(CFG.LOCAL_DB_USER)
+ + ":"
+ + quote(CFG.LOCAL_DB_PASSWORD)
+ + "@"
+ + CFG.LOCAL_DB_HOST
+ + ":"
+ + str(CFG.LOCAL_DB_PORT)
+ )
# check and auto create mysqldatabase
try:
# try to connect
@@ -53,20 +54,19 @@ if CFG.LOCAL_DB_TYPE == 'mysql':
# if connect failed, create dbgpt database
logger.error(f"{db_name} not connect success!")
- engine = create_engine(f"mysql+pymysql://"
- + quote(CFG.LOCAL_DB_USER)
- + ":"
- + quote(CFG.LOCAL_DB_PASSWORD)
- + "@"
- + CFG.LOCAL_DB_HOST
- + ":"
- + str(CFG.LOCAL_DB_PORT)
- + f"/{db_name}"
- )
+ engine = create_engine(
+ f"mysql+pymysql://"
+ + quote(CFG.LOCAL_DB_USER)
+ + ":"
+ + quote(CFG.LOCAL_DB_PASSWORD)
+ + "@"
+ + CFG.LOCAL_DB_HOST
+ + ":"
+ + str(CFG.LOCAL_DB_PORT)
+ + f"/{db_name}"
+ )
else:
- engine = create_engine(f'sqlite:///{db_path}')
-
-
+ engine = create_engine(f"sqlite:///{db_path}")
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@@ -81,16 +81,16 @@ Base = declarative_base()
alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)
-alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
+alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)
-alembic_cfg.set_main_option('script_location', default_db_path + "/alembic")
+alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")
# 将模型和会话传递给Alembic配置
-alembic_cfg.attributes['target_metadata'] = Base.metadata
-alembic_cfg.attributes['session'] = session
+alembic_cfg.attributes["target_metadata"] = Base.metadata
+alembic_cfg.attributes["session"] = session
# # 创建表
@@ -106,7 +106,7 @@ def ddl_init_and_upgrade():
# command.upgrade(alembic_cfg, 'head')
# subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
with engine.connect() as connection:
- alembic_cfg.attributes['connection'] = connection
+ alembic_cfg.attributes["connection"] = connection
heads = command.heads(alembic_cfg)
print("heads:" + str(heads))
diff --git a/pilot/base_modules/module_factory.py b/pilot/base_modules/module_factory.py
index 139597f9c..8b1378917 100644
--- a/pilot/base_modules/module_factory.py
+++ b/pilot/base_modules/module_factory.py
@@ -1,2 +1 @@
-
diff --git a/pilot/common/string_utils.py b/pilot/common/string_utils.py
index 14bf5082e..170f0519a 100644
--- a/pilot/common/string_utils.py
+++ b/pilot/common/string_utils.py
@@ -1,27 +1,30 @@
import re
+
def is_all_chinese(text):
### Determine whether the string is pure Chinese
- pattern = re.compile(r'^[一-龥]+$')
+ pattern = re.compile(r"^[一-龥]+$")
match = re.match(pattern, text)
return match is not None
def is_number_chinese(text):
### Determine whether the string is numbers and Chinese
- pattern = re.compile(r'^[\d一-龥]+$')
+ pattern = re.compile(r"^[\d一-龥]+$")
match = re.match(pattern, text)
return match is not None
+
def is_chinese_include_number(text):
### Determine whether the string is pure Chinese or Chinese containing numbers
- pattern = re.compile(r'^[一-龥]+[\d一-龥]*$')
+ pattern = re.compile(r"^[一-龥]+[\d一-龥]*$")
match = re.match(pattern, text)
return match is not None
+
def is_scientific_notation(string):
# 科学计数法的正则表达式
- pattern = r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$'
+ pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$"
# 使用正则表达式匹配字符串
match = re.match(pattern, str(string))
# 判断是否匹配成功
@@ -30,28 +33,30 @@ def is_scientific_notation(string):
else:
return False
+
def extract_content(long_string, s1, s2, is_include: bool = False):
# extract text
- match_map ={}
+ match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
if is_include:
end_index = long_string.find(s2, start_index + len(s1) + 1)
- extracted_content = long_string[start_index:end_index + len(s2)]
+ extracted_content = long_string[start_index : end_index + len(s2)]
else:
end_index = long_string.find(s2, start_index + len(s1))
- extracted_content = long_string[start_index + len(s1):end_index]
+ extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
start_index = long_string.find(s1, start_index + 1)
return match_map
+
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
# extract text open ending
match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
- if long_string.find(s2, start_index) <=0:
+        # find() returns -1 when s2 is absent from the remainder
+        if long_string.find(s2, start_index) < 0:
end_index = len(long_string)
else:
if is_include:
@@ -59,19 +64,18 @@ def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
else:
end_index = long_string.find(s2, start_index + len(s1))
if is_include:
- extracted_content = long_string[start_index:end_index + len(s2)]
+ extracted_content = long_string[start_index : end_index + len(s2)]
else:
- extracted_content = long_string[start_index + len(s1):end_index]
+ extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
- start_index= long_string.find(s1, start_index + 1)
+ start_index = long_string.find(s1, start_index + 1)
return match_map
-
-if __name__=="__main__":
+if __name__ == "__main__":
s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
s1 = "123"
s2 = "456"
- print(extract_content_open_ending(s, s1, s2, True))
\ No newline at end of file
+ print(extract_content_open_ending(s, s1, s2, True))
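+    # prints: {4: '123efghijkjhhh456', 24: '123aa456', 35: '123bb456', 45: '123'}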
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 0db89b70b..a25462c5e 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -55,20 +55,24 @@ class Config(metaclass=Singleton):
self.tongyi_proxy_api_key = os.getenv("TONGYI_PROXY_API_KEY")
if self.tongyi_proxy_api_key:
os.environ["tongyi_proxyllm_proxy_api_key"] = self.tongyi_proxy_api_key
-
+
# zhipu
self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY")
if self.zhipu_proxy_api_key:
os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
- os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv("ZHIPU_MODEL_VERSION")
+ os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
+ "ZHIPU_MODEL_VERSION"
+ )
# wenxin
self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
- self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY")
+ self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY")
self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
- os.environ["wenxin_proxyllm_proxy_api_secret"] = self.wenxin_proxy_api_secret
+ os.environ[
+ "wenxin_proxyllm_proxy_api_secret"
+ ] = self.wenxin_proxy_api_secret
os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version
# xunfei spark
@@ -90,8 +94,7 @@ class Config(metaclass=Singleton):
os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key
os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
-
-
+
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
@@ -172,7 +175,6 @@ class Config(metaclass=Singleton):
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
)
-
self.LOCAL_DB_MANAGE = None
###dbgpt meta info database connection configuration
@@ -190,7 +192,6 @@ class Config(metaclass=Singleton):
self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb")
-
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
diff --git a/pilot/connections/__init__.py b/pilot/connections/__init__.py
index ce13a69f3..8cc9799db 100644
--- a/pilot/connections/__init__.py
+++ b/pilot/connections/__init__.py
@@ -1 +1 @@
-from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao
\ No newline at end of file
+from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao
diff --git a/pilot/connections/manages/connect_config_db.py b/pilot/connections/manages/connect_config_db.py
index 42307b243..7443d18ad 100644
--- a/pilot/connections/manages/connect_config_db.py
+++ b/pilot/connections/manages/connect_config_db.py
@@ -4,10 +4,13 @@ from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
+
class ConnectConfigEntity(Base):
- __tablename__ = 'connect_config'
- id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
- db_type = Column(String(255), nullable=False, comment="db type")
+ __tablename__ = "connect_config"
+ id = Column(
+ Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
+ )
+ db_type = Column(String(255), nullable=False, comment="db type")
db_name = Column(String(255), nullable=False, comment="db name")
db_path = Column(String(255), nullable=True, comment="file db path")
db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
@@ -17,8 +20,8 @@ class ConnectConfigEntity(Base):
comment = Column(Text, nullable=True, comment="db comment")
__table_args__ = (
- UniqueConstraint('db_name', name="uk_db"),
- Index('idx_q_db_type', 'db_type'),
+ UniqueConstraint("db_name", name="uk_db"),
+ Index("idx_q_db_type", "db_type"),
)
@@ -43,9 +46,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
raise Exception("db_name is None")
db_connect = session.query(ConnectConfigEntity)
- db_connect = db_connect.filter(
- ConnectConfigEntity.db_name == db_name
- )
+ db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
db_connect.delete()
session.commit()
session.close()
@@ -53,10 +54,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_by_name(self, db_name: str) -> ConnectConfigEntity:
session = self.get_session()
db_connect = session.query(ConnectConfigEntity)
- db_connect = db_connect.filter(
- ConnectConfigEntity.db_name == db_name
- )
+ db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
result = db_connect.first()
session.close()
return result
-
diff --git a/pilot/memory/__init__.py b/pilot/memory/__init__.py
index 2e8c7af1b..77b30c893 100644
--- a/pilot/memory/__init__.py
+++ b/pilot/memory/__init__.py
@@ -1 +1 @@
-from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao
\ No newline at end of file
+from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao
diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py
index 4d9291e0b..a8a09153c 100644
--- a/pilot/memory/chat_history/base.py
+++ b/pilot/memory/chat_history/base.py
@@ -5,15 +5,17 @@ from typing import List
from enum import Enum
from pilot.scene.message import OnceConversation
+
class MemoryStoreType(Enum):
- File= 'file'
- Memory = 'memory'
- DB = 'db'
- DuckDb = 'duckdb'
+ File = "file"
+ Memory = "memory"
+ DB = "db"
+ DuckDb = "duckdb"
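+    # The concrete store is selected via CHAT_HISTORY_STORE_TYPE in .env
+    # (default: "duckdb").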
class BaseChatHistoryMemory(ABC):
store_type: MemoryStoreType
+
def __init__(self):
self.conversations: List[OnceConversation] = []
@@ -56,4 +58,4 @@ class BaseChatHistoryMemory(ABC):
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
- pass
\ No newline at end of file
+ pass
diff --git a/pilot/memory/chat_history/chat_hisotry_factory.py b/pilot/memory/chat_history/chat_hisotry_factory.py
index 6c36053dd..64d30e971 100644
--- a/pilot/memory/chat_history/chat_hisotry_factory.py
+++ b/pilot/memory/chat_history/chat_hisotry_factory.py
@@ -4,9 +4,7 @@ from pilot.configs.config import Config
CFG = Config()
-
class ChatHistory:
-
def __init__(self):
self.memory_type = MemoryStoreType.DB.value
self.mem_store_class_map = {}
@@ -14,15 +12,16 @@ class ChatHistory:
from .store_type.file_history import FileHistoryMemory
from .store_type.meta_db_history import DbHistoryMemory
from .store_type.mem_history import MemHistoryMemory
+
self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
-
def get_store_instance(self, chat_session_id):
- return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id)
-
+ return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
+ chat_session_id
+ )
def get_store_cls(self):
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
diff --git a/pilot/memory/chat_history/chat_history_db.py b/pilot/memory/chat_history/chat_history_db.py
index f49e1983c..2b1a57c28 100644
--- a/pilot/memory/chat_history/chat_history_db.py
+++ b/pilot/memory/chat_history/chat_history_db.py
@@ -4,20 +4,28 @@ from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
+
class ChatHistoryEntity(Base):
- __tablename__ = 'chat_history'
- id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
- conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id")
+ __tablename__ = "chat_history"
+ id = Column(
+ Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
+ )
+ conv_uid = Column(
+ String(255),
+ unique=False,
+ nullable=False,
+ comment="Conversation record unique id",
+ )
chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
summary = Column(String(255), nullable=False, comment="Conversation record summary")
user_name = Column(String(255), nullable=True, comment="interlocutor")
messages = Column(Text, nullable=True, comment="Conversation details")
__table_args__ = (
- UniqueConstraint('conv_uid', name="uk_conversation"),
- Index('idx_q_user', 'user_name'),
- Index('idx_q_mode', 'chat_mode'),
- Index('idx_q_conv', 'summary'),
+ UniqueConstraint("conv_uid", name="uk_conversation"),
+ Index("idx_q_user", "user_name"),
+ Index("idx_q_mode", "chat_mode"),
+ Index("idx_q_conv", "summary"),
)
@@ -31,9 +39,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
if user_name:
- chat_history = chat_history.filter(
- ChatHistoryEntity.user_name == user_name
- )
+ chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
@@ -50,13 +56,11 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
finally:
session.close()
- def update_message_by_uid(self, message: str, conv_uid:str):
+ def update_message_by_uid(self, message: str, conv_uid: str):
session = self.get_session()
try:
chat_history = session.query(ChatHistoryEntity)
- chat_history = chat_history.filter(
- ChatHistoryEntity.conv_uid == conv_uid
- )
+ chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
updated = chat_history.update({ChatHistoryEntity.messages: message})
session.commit()
-            return updated.id
+            # Query.update() returns the number of matched rows, not an entity
+            return updated
@@ -69,9 +73,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
raise Exception("conv_uid is None")
chat_history = session.query(ChatHistoryEntity)
- chat_history = chat_history.filter(
- ChatHistoryEntity.conv_uid == conv_uid
- )
+ chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
chat_history.delete()
session.commit()
session.close()
@@ -79,10 +81,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
- chat_history = chat_history.filter(
- ChatHistoryEntity.conv_uid == conv_uid
- )
+ chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
result = chat_history.first()
session.close()
return result
-
diff --git a/pilot/memory/chat_history/store_type/duckdb_history.py b/pilot/memory/chat_history/store_type/duckdb_history.py
index 97aae159b..28c92a142 100644
--- a/pilot/memory/chat_history/store_type/duckdb_history.py
+++ b/pilot/memory/chat_history/store_type/duckdb_history.py
@@ -148,7 +148,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return json.loads(context[0])
return None
-
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
if os.path.isfile(duckdb_path):
diff --git a/pilot/memory/chat_history/store_type/file_history.py b/pilot/memory/chat_history/store_type/file_history.py
index fa1143309..a4623db36 100644
--- a/pilot/memory/chat_history/store_type/file_history.py
+++ b/pilot/memory/chat_history/store_type/file_history.py
@@ -17,7 +17,7 @@ CFG = Config()
class FileHistoryMemory(BaseChatHistoryMemory):
- store_type: str = MemoryStoreType.File.value
+ store_type: str = MemoryStoreType.File.value
def __init__(self, chat_session_id: str):
now = datetime.datetime.now()
@@ -49,5 +49,3 @@ class FileHistoryMemory(BaseChatHistoryMemory):
def clear(self) -> None:
self.file_path.write_text(json.dumps([]))
-
-
diff --git a/pilot/memory/chat_history/store_type/mem_history.py b/pilot/memory/chat_history/store_type/mem_history.py
index 5c3ddc217..81f1438de 100644
--- a/pilot/memory/chat_history/store_type/mem_history.py
+++ b/pilot/memory/chat_history/store_type/mem_history.py
@@ -10,7 +10,7 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
- store_type: str = MemoryStoreType.Memory.value
+ store_type: str = MemoryStoreType.Memory.value
histroies_map = FixedSizeDict(100)
diff --git a/pilot/memory/chat_history/store_type/meta_db_history.py b/pilot/memory/chat_history/store_type/meta_db_history.py
index c1fc0ec5d..137d0b161 100644
--- a/pilot/memory/chat_history/store_type/meta_db_history.py
+++ b/pilot/memory/chat_history/store_type/meta_db_history.py
@@ -12,18 +12,22 @@ from pilot.scene.message import (
from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao
from pilot.memory.chat_history.base import MemoryStoreType
+
CFG = Config()
logger = logging.getLogger("db_chat_history")
+
class DbHistoryMemory(BaseChatHistoryMemory):
- store_type: str = MemoryStoreType.DB.value
+ store_type: str = MemoryStoreType.DB.value
+
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.chat_history_dao = ChatHistoryDao()
def messages(self) -> List[OnceConversation]:
-
- chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
+ chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
+ self.chat_seesion_id
+ )
if chat_history:
context = chat_history.messages
if context:
@@ -31,7 +35,6 @@ class DbHistoryMemory(BaseChatHistoryMemory):
return conversations
return []
-
def create(self, chat_mode, summary: str, user_name: str) -> None:
try:
chat_history: ChatHistoryEntity = ChatHistoryEntity()
@@ -43,10 +46,11 @@ class DbHistoryMemory(BaseChatHistoryMemory):
except Exception as e:
logger.error("init create conversation log error!" + str(e))
-
def append(self, once_message: OnceConversation) -> None:
logger.info("db history append:{}", once_message)
- chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
+ chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
+ self.chat_seesion_id
+ )
conversations: List[OnceConversation] = []
if chat_history:
context = chat_history.messages
@@ -59,7 +63,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
chat_history.conv_uid = self.chat_seesion_id
chat_history.chat_mode = once_message.chat_mode
chat_history.user_name = "default"
- chat_history.summary = once_message.get_user_conv().content
+ chat_history.summary = once_message.get_user_conv().content
conversations.append(_conversation_to_dic(once_message))
chat_history.messages = json.dumps(conversations, ensure_ascii=False)
@@ -67,34 +71,31 @@ class DbHistoryMemory(BaseChatHistoryMemory):
self.chat_history_dao.update(chat_history)
def update(self, messages: List[OnceConversation]) -> None:
- self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id)
-
+ self.chat_history_dao.update_message_by_uid(
+ json.dumps(messages, ensure_ascii=False), self.chat_seesion_id
+ )
def delete(self) -> bool:
self.chat_history_dao.delete(self.chat_seesion_id)
-
def conv_info(self, conv_uid: str = None) -> None:
logger.info("conv_info:{}", conv_uid)
- chat_history = self.chat_history_dao.get_by_uid(conv_uid)
+ chat_history = self.chat_history_dao.get_by_uid(conv_uid)
return chat_history.__dict__
-
def get_messages(self) -> List[OnceConversation]:
logger.info("get_messages:{}", self.chat_seesion_id)
- chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
+ chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
if chat_history:
context = chat_history.messages
return json.loads(context)
return []
-
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
-
chat_history_dao = ChatHistoryDao()
history_list = chat_history_dao.list_last_20()
result = []
for history in history_list:
result.append(history.__dict__)
- return result
\ No newline at end of file
+ return result
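For reference, the `DbHistoryMemory` changes above persist the whole conversation list as a single JSON string in the `messages` column and parse it back on every read. A minimal stdlib-only sketch of that round-trip, with hypothetical sample data shaped roughly like the output of `_conversation_to_dic`:

```python
import json

# Illustrative conversation payload; the exact fields are assumptions.
conversations = [
    {
        "chat_mode": "chat_normal",
        "messages": [{"role": "human", "content": "你好"}],
    }
]

# ensure_ascii=False matches the code above: non-ASCII content is stored
# readably instead of being escaped to \uXXXX sequences.
serialized = json.dumps(conversations, ensure_ascii=False)
restored = json.loads(serialized)
assert restored == conversations
```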
diff --git a/pilot/meta_data/alembic/env.py b/pilot/meta_data/alembic/env.py
index e40929dad..507a27ab4 100644
--- a/pilot/meta_data/alembic/env.py
+++ b/pilot/meta_data/alembic/env.py
@@ -66,12 +66,14 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
- if engine.dialect.name == 'sqlite':
- context.configure(connection=engine.connect(), target_metadata=target_metadata, render_as_batch=True)
- else:
+ if engine.dialect.name == "sqlite":
context.configure(
- connection=connection, target_metadata=target_metadata
+ connection=engine.connect(),
+ target_metadata=target_metadata,
+ render_as_batch=True,
)
+ else:
+ context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
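Context for the `render_as_batch=True` branch above: SQLite's `ALTER TABLE` support is limited, so Alembic's batch mode emulates most schema changes by creating a new table, copying rows, and renaming. A hypothetical migration body (not from this repository) that relies on that mode:

```python
import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # With render_as_batch=True configured in env.py, this ALTER is rewritten
    # as create-copy-drop-rename, which SQLite can execute reliably.
    with op.batch_alter_table("chat_history") as batch_op:
        batch_op.add_column(sa.Column("user_name", sa.String(100), nullable=True))


def downgrade() -> None:
    with op.batch_alter_table("chat_history") as batch_op:
        batch_op.drop_column("user_name")
```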
diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py
index 3bb9d7b93..7751ddb5a 100644
--- a/pilot/model/cluster/controller/controller.py
+++ b/pilot/model/cluster/controller/controller.py
@@ -145,12 +145,12 @@ def initialize_controller(
controller.backend = LocalModelController()
if app:
- app.include_router(router, prefix="/api", tags=['Model'])
+ app.include_router(router, prefix="/api", tags=["Model"])
else:
import uvicorn
app = FastAPI()
- app.include_router(router, prefix="/api", tags=['Model'])
+ app.include_router(router, prefix="/api", tags=["Model"])
uvicorn.run(app, host=host, port=port, log_level="info")
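The router mounting in `initialize_controller` is the standard FastAPI pattern: routes declared on an `APIRouter` are attached under the `/api` prefix and grouped under the `Model` tag in the generated OpenAPI docs. A self-contained sketch of the same pattern (the `/health` endpoint is hypothetical):

```python
import uvicorn
from fastapi import APIRouter, FastAPI

router = APIRouter()


@router.get("/health")  # hypothetical endpoint, for illustration only
def health() -> dict:
    return {"status": "ok"}


app = FastAPI()
# Exposed as GET /api/health and documented under the "Model" tag.
app.include_router(router, prefix="/api", tags=["Model"])

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000, log_level="info")
```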
diff --git a/pilot/model/proxy/llms/baichuan.py b/pilot/model/proxy/llms/baichuan.py
index 6dd5cacad..ae4f72283 100644
--- a/pilot/model/proxy/llms/baichuan.py
+++ b/pilot/model/proxy/llms/baichuan.py
@@ -9,18 +9,21 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"
+
def _calculate_md5(text: str) -> str:
- """Calculate md5 """
+ """Calculate md5"""
md5 = hashlib.md5()
md5.update(text.encode("utf-8"))
encrypted = md5.hexdigest()
return encrypted
+
def _sign(data: dict, secret_key: str, timestamp: str):
data_str = json.dumps(data)
signature = _calculate_md5(secret_key + data_str + timestamp)
return signature
+
def baichuan_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=4096
):
@@ -28,12 +31,11 @@ def baichuan_generate_stream(
url = "https://api.baichuan-ai.com/v1/stream/chat"
model_name = model_params.proxyllm_backend or BAICHUAN_DEFAULT_MODEL
- proxy_api_key = model_params.proxy_api_key
- proxy_api_secret = model_params.proxy_api_secret
-
+ proxy_api_key = model_params.proxy_api_key
+ proxy_api_secret = model_params.proxy_api_secret
history = []
- messages: List[ModelMessage] = params["messages"]
+ messages: List[ModelMessage] = params["messages"]
# Add history conversation
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
@@ -47,23 +49,23 @@ def baichuan_generate_stream(
payload = {
"model": model_name,
- "messages": history,
+ "messages": history,
"parameters": {
"temperature": params.get("temperature"),
- "top_k": params.get("top_k", 10)
- }
+ "top_k": params.get("top_k", 10),
+ },
}
timestamp = int(time.time())
_signature = _sign(payload, proxy_api_secret, str(timestamp))
-
+
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + proxy_api_key,
"X-BC-Request-Id": params.get("request_id") or "dbgpt",
"X-BC-Timestamp": str(timestamp),
"X-BC-Signature": _signature,
- "X-BC-Sign-Algo": "MD5",
+ "X-BC-Sign-Algo": "MD5",
}
res = requests.post(url=url, json=payload, headers=headers, stream=True)
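A standalone sketch of the signing scheme `_sign` implements above: `X-BC-Signature` is the MD5 hex digest of `secret_key + JSON body + timestamp`, so the signed string must be byte-identical to the JSON the request actually sends. Credentials below are placeholders:

```python
import hashlib
import json
import time

proxy_api_key = "YOUR_API_KEY"  # placeholder
proxy_api_secret = "YOUR_SECRET_KEY"  # placeholder
payload = {"model": "Baichuan2-53B", "messages": [{"role": "user", "content": "hi"}]}

timestamp = str(int(time.time()))
data_str = json.dumps(payload)  # must match the serialized request body exactly
signature = hashlib.md5(
    (proxy_api_secret + data_str + timestamp).encode("utf-8")
).hexdigest()

headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer " + proxy_api_key,
    "X-BC-Timestamp": timestamp,
    "X-BC-Signature": signature,
    "X-BC-Sign-Algo": "MD5",
}
```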
diff --git a/pilot/model/proxy/llms/spark.py b/pilot/model/proxy/llms/spark.py
index 2a6a1579a..72a9ccd2f 100644
--- a/pilot/model/proxy/llms/spark.py
+++ b/pilot/model/proxy/llms/spark.py
@@ -3,7 +3,7 @@ import json
import base64
import hmac
import hashlib
-import websockets
+import websockets
from datetime import datetime
from typing import List
from time import mktime
@@ -15,25 +15,25 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v2"
+
def spark_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
model_params = model.get_params()
proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
- proxy_api_key = model_params.proxy_api_key
- proxy_api_secret = model_params.proxy_api_secret
- proxy_app_id = model_params.proxy_app_id
+ proxy_api_key = model_params.proxy_api_key
+ proxy_api_secret = model_params.proxy_api_secret
+ proxy_app_id = model_params.proxy_app_id
if proxy_api_version == SPARK_DEFAULT_API_VERSION:
url = "ws://spark-api.xf-yun.com/v2.1/chat"
- domain = "generalv2"
+ domain = "generalv2"
else:
domain = "general"
url = "ws://spark-api.xf-yun.com/v1.1/chat"
-
- messages: List[ModelMessage] = params["messages"]
-
+ messages: List[ModelMessage] = params["messages"]
+
history = []
# Add history conversation
for message in messages:
@@ -45,7 +45,7 @@ def spark_generate_stream(
history.append({"role": "assistant", "content": message.content})
else:
pass
-
+
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
request_url = spark_api.gen_url()
@@ -55,78 +55,73 @@ def spark_generate_stream(
if m["role"] == "user":
last_user_input = m
break
-
+
data = {
- "header": {
- "app_id": proxy_app_id,
- "uid": params.get("request_id", 1)
- },
+ "header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": context_len,
"auditing": "default",
- "temperature": params.get("temperature")
+ "temperature": params.get("temperature"),
}
},
- "payload": {
- "message": {
- "text": last_user_input.get("content")
- }
- }
+ "payload": {"message": {"text": last_user_input.get("content")}},
}
async_call(request_url, data)
+
async def async_call(request_url, data):
async with websockets.connect(request_url) as ws:
await ws.send(json.dumps(data, ensure_ascii=False))
finish = False
while not finish:
-            chunk = ws.recv()
+            chunk = ws.recv()
+ chunk = ws.recv()
response = json.loads(chunk)
if response.get("header", {}).get("status") == 2:
finish = True
if text := response.get("payload", {}).get("choices", {}).get("text"):
- yield text[0]["content"]
-
+ yield text[0]["content"]
+
+
class SparkAPI:
-
- def __init__(self, appid: str, api_key: str, api_secret: str, spark_url: str) -> None:
+ def __init__(
+ self, appid: str, api_key: str, api_secret: str, spark_url: str
+ ) -> None:
self.appid = appid
self.api_key = api_key
self.api_secret = api_secret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
-
+
self.spark_url = spark_url
-
def gen_url(self):
-
now = datetime.now()
- date = format_date_time(mktime(now.timetuple()))
+ date = format_date_time(mktime(now.timetuple()))
_signature = "host: " + self.host + "\n"
_signature += "data: " + date + "\n"
_signature += "GET " + self.path + " HTTP/1.1"
- _signature_sha = hmac.new(self.api_secret.encode("utf-8"), _signature.encode("utf-8"),
- digestmod=hashlib.sha256).digest()
+ _signature_sha = hmac.new(
+ self.api_secret.encode("utf-8"),
+ _signature.encode("utf-8"),
+ digestmod=hashlib.sha256,
+ ).digest()
- _signature_sha_base64 = base64.b64encode(_signature_sha).decode(encoding="utf-8")
+ _signature_sha_base64 = base64.b64encode(_signature_sha).decode(
+ encoding="utf-8"
+ )
_authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'"
- authorization = base64.b64encode(_authorization.encode('utf-8')).decode(encoding='utf-8')
-
- v = {
- "authorization": authorization,
- "date": date,
- "host": self.host
- }
+ authorization = base64.b64encode(_authorization.encode("utf-8")).decode(
+ encoding="utf-8"
+ )
+
+ v = {"authorization": authorization, "date": date, "host": self.host}
url = self.spark_url + "?" + urlencode(v)
- return url
-
-
\ No newline at end of file
+ return url
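`SparkAPI.gen_url` authenticates the websocket handshake by signing a `host`/`date`/request-line string with HMAC-SHA256, base64-encoding the digest, wrapping it in an authorization string that is base64-encoded again, and appending everything as query parameters. A self-contained sketch mirroring that logic with placeholder credentials (it reproduces the code's `"data: "` literal verbatim):

```python
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

api_key, api_secret = "KEY", "SECRET"  # placeholders
spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
host, path = urlparse(spark_url).netloc, urlparse(spark_url).path

date = format_date_time(mktime(datetime.now().timetuple()))  # RFC 1123 date
signature_origin = "host: " + host + "\n" + "data: " + date + "\n" + f"GET {path} HTTP/1.1"
signature_sha = hmac.new(
    api_secret.encode("utf-8"),
    signature_origin.encode("utf-8"),
    digestmod=hashlib.sha256,
).digest()
signature = base64.b64encode(signature_sha).decode("utf-8")

authorization_origin = (
    f"api_key='{api_key}', algorithm='hmac-sha256', "
    f"headers='host date request-line', signature='{signature}'"
)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")

request_url = spark_url + "?" + urlencode(
    {"authorization": authorization, "date": date, "host": host}
)
print(request_url)
```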
diff --git a/pilot/model/proxy/llms/tongyi.py b/pilot/model/proxy/llms/tongyi.py
index f1a95928f..fb826e49c 100644
--- a/pilot/model/proxy/llms/tongyi.py
+++ b/pilot/model/proxy/llms/tongyi.py
@@ -8,10 +8,11 @@ logger = logging.getLogger(__name__)
def tongyi_generate_stream(
- model: ProxyModel, tokenizer, params, device, context_len=2048
+ model: ProxyModel, tokenizer, params, device, context_len=2048
):
import dashscope
from dashscope import Generation
+
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
@@ -62,14 +63,14 @@ def tongyi_generate_stream(
messages=history,
top_p=params.get("top_p", 0.8),
stream=True,
- result_format='message'
+ result_format="message",
)
for r in res:
if r:
- if r['status_code'] == 200:
+ if r["status_code"] == 200:
content = r["output"]["choices"][0]["message"].get("content")
yield content
else:
- content = r['code'] + ":" + r["message"]
+ content = r["code"] + ":" + r["message"]
yield content
diff --git a/pilot/model/proxy/llms/wenxin.py b/pilot/model/proxy/llms/wenxin.py
index 262528939..acc82907c 100644
--- a/pilot/model/proxy/llms/wenxin.py
+++ b/pilot/model/proxy/llms/wenxin.py
@@ -6,20 +6,26 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from cachetools import cached, TTLCache
+
@cached(TTLCache(1, 1800))
def _build_access_token(api_key: str, secret_key: str) -> str:
"""
- Generate Access token according AK, SK
+    Generate an access token from the API key (AK) and secret key (SK).
"""
-
+
url = "https://aip.baidubce.com/oauth/2.0/token"
- params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
+ params = {
+ "grant_type": "client_credentials",
+ "client_id": api_key,
+ "client_secret": secret_key,
+ }
res = requests.get(url=url, params=params)
-
+
if res.status_code == 200:
return res.json().get("access_token")
+
def wenxin_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -28,23 +34,20 @@ def wenxin_generate_stream(
"ERNIE-Bot-turbo": "eb-instant",
}
- model_params = model.get_params()
- model_name = model_params.proxyllm_backend
+ model_params = model.get_params()
+ model_name = model_params.proxyllm_backend
model_version = MODEL_VERSION.get(model_name)
if not model_version:
yield f"Unsupport model version {model_name}"
-
- proxy_api_key = model_params.proxy_api_key
- proxy_api_secret = model_params.proxy_api_secret
- access_token = _build_access_token(proxy_api_key, proxy_api_secret)
-
- headers = {
- "Content-Type": "application/json",
- "Accept": "application/json"
- }
- proxy_server_url = f'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}'
-
+ proxy_api_key = model_params.proxy_api_key
+ proxy_api_secret = model_params.proxy_api_secret
+ access_token = _build_access_token(proxy_api_key, proxy_api_secret)
+
+ headers = {"Content-Type": "application/json", "Accept": "application/json"}
+
+ proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}"
+
if not access_token:
yield "Failed to get access token. please set the correct api_key and secret key."
@@ -86,7 +89,7 @@ def wenxin_generate_stream(
"messages": history,
"system": system,
"temperature": params.get("temperature"),
- "stream": True
+ "stream": True,
}
text = ""
@@ -106,6 +109,3 @@ def wenxin_generate_stream(
content = obj["result"]
text += content
yield text
-
-
-
\ No newline at end of file
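The `@cached(TTLCache(1, 1800))` decorator on `_build_access_token` means the token endpoint is hit at most once per 30 minutes for a given `(api_key, secret_key)` pair. A small self-contained demonstration of that caching behavior; the fetch function is a stand-in, with a short TTL so the expiry is observable:

```python
import itertools
import time

from cachetools import TTLCache, cached

_counter = itertools.count()


@cached(TTLCache(maxsize=1, ttl=2))  # the real code uses ttl=1800 (30 minutes)
def fetch_token(api_key: str, secret_key: str) -> str:
    # Stand-in for the HTTP call to the OAuth token endpoint.
    return f"token-{next(_counter)}"


first = fetch_token("ak", "sk")
assert fetch_token("ak", "sk") == first  # served from the cache
time.sleep(2.5)
assert fetch_token("ak", "sk") != first  # entry expired, re-fetched
```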
diff --git a/pilot/model/proxy/llms/zhipu.py b/pilot/model/proxy/llms/zhipu.py
index 3a2a1edbb..89e7dd9a0 100644
--- a/pilot/model/proxy/llms/zhipu.py
+++ b/pilot/model/proxy/llms/zhipu.py
@@ -7,6 +7,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
+
def zhipu_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -16,9 +17,10 @@ def zhipu_generate_stream(
# TODO proxy model use unified config?
proxy_api_key = model_params.proxy_api_key
- proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
+ proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
import zhipuai
+
zhipuai.api_key = proxy_api_key
history = []
@@ -63,4 +65,4 @@ def zhipu_generate_stream(
)
for r in res.events():
if r.event == "add":
- yield r.data
\ No newline at end of file
+ yield r.data
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 100b40320..ea569b3e4 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -297,7 +297,6 @@ async def params_load(
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
-
history_fac = ChatHistory()
history_mem = history_fac.get_store_instance(con_uid)
history_mem.delete()
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index be0231525..d1e6ad8ac 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -54,7 +54,7 @@ class BaseChat(ABC):
)
chat_history_fac = ChatHistory()
### can configurable storage methods
- self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
+ self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(
diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py
index bb9872308..4734c8106 100644
--- a/pilot/scene/chat_agent/chat.py
+++ b/pilot/scene/chat_agent/chat.py
@@ -20,10 +20,11 @@ logger = logging.getLogger("chat_agent")
class ChatAgent(BaseChat):
chat_scene: str = ChatScene.ChatAgent.value()
chat_retention_rounds = 0
+
def __init__(self, chat_param: Dict):
- if not chat_param['select_param']:
+ if not chat_param["select_param"]:
raise ValueError("Please select a Plugin!")
- self.select_plugins = chat_param['select_param'].split(",")
+ self.select_plugins = chat_param["select_param"].split(",")
chat_param["chat_mode"] = ChatScene.ChatAgent
super().__init__(chat_param=chat_param)
@@ -31,8 +32,12 @@ class ChatAgent(BaseChat):
self.plugins_prompt_generator.command_registry = CFG.command_registry
# load select plugin
- agent_module = CFG.SYSTEM_APP.get_component(ComponentType.AGENT_HUB, ModuleAgent)
- self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
+ agent_module = CFG.SYSTEM_APP.get_component(
+ ComponentType.AGENT_HUB, ModuleAgent
+ )
+ self.plugins_prompt_generator = agent_module.load_select_plugin(
+ self.plugins_prompt_generator, self.select_plugins
+ )
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
@@ -53,4 +58,3 @@ class ChatAgent(BaseChat):
def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
-
diff --git a/pilot/scene/chat_agent/prompt.py b/pilot/scene/chat_agent/prompt.py
index 1c5229f89..a42fd6363 100644
--- a/pilot/scene/chat_agent/prompt.py
+++ b/pilot/scene/chat_agent/prompt.py
@@ -54,7 +54,7 @@ _DEFAULT_TEMPLATE = (
)
-_PROMPT_SCENE_DEFINE=(
+_PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
@@ -76,7 +76,7 @@ prompt = PromptTemplate(
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
- temperature = 1
+ temperature=1
# example_selector=plugin_example,
)
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
index 4611aa14d..9599c1402 100644
--- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
@@ -36,7 +36,7 @@ class ChatExcel(BaseChat):
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
)
)
- self.api_call = ApiCall(display_registry = CFG.command_disply)
+ self.api_call = ApiCall(display_registry=CFG.command_disply)
super().__init__(chat_param=chat_param)
def _generate_numbered_list(self) -> str:
@@ -79,4 +79,4 @@ class ChatExcel(BaseChat):
def stream_plugin_call(self, text):
text = text.replace("\n", " ")
print(f"stream_plugin_call:{text}")
- return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
\ No newline at end of file
+ return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
index fad30bb56..ee82b51a0 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
@@ -44,7 +44,7 @@ _DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
-PROMPT_SCENE_DEFINE =(
+PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py
index 0661737f0..c9e4aa785 100644
--- a/pilot/scene/chat_data/chat_excel/excel_reader.py
+++ b/pilot/scene/chat_data/chat_excel/excel_reader.py
@@ -9,8 +9,21 @@ import pandas as pd
import chardet
import pandas as pd
import numpy as np
-from pyparsing import CaselessKeyword, Word, alphas, alphanums, delimitedList, Forward, Group, Optional,\
- Literal, infixNotation, opAssoc, unicodeString,Regex
+from pyparsing import (
+ CaselessKeyword,
+ Word,
+ alphas,
+ alphanums,
+ delimitedList,
+ Forward,
+ Group,
+ Optional,
+ Literal,
+ infixNotation,
+ opAssoc,
+ unicodeString,
+ Regex,
+)
from pilot.common.pd_utils import csv_colunm_foramt
from pilot.common.string_utils import is_chinese_include_number
@@ -21,14 +34,15 @@ def excel_colunm_format(old_name: str) -> str:
new_column = new_column.replace(" ", "_")
return new_column
+
def detect_encoding(file_path):
    # Read the file's binary data
- with open(file_path, 'rb') as f:
+ with open(file_path, "rb") as f:
data = f.read()
    # Use chardet to detect the file encoding
result = chardet.detect(data)
- encoding = result['encoding']
- confidence = result['confidence']
+ encoding = result["encoding"]
+ confidence = result["confidence"]
return encoding, confidence
@@ -39,10 +53,11 @@ def add_quotes_ex(sql: str, column_names):
sql = sql.replace(column_name, f'"{column_name}"')
return sql
+
def parse_sql(sql):
    # Define keywords and identifiers
select_stmt = Forward()
- column = Regex(r'[\w一-龥]*')
+ column = Regex(r"[\w一-龥]*")
table = Word(alphanums)
join_expr = Forward()
where_expr = Forward()
@@ -62,21 +77,24 @@ def parse_sql(sql):
not_in_keyword = CaselessKeyword("NOT IN")
    # Define the grammar rules
- select_stmt <<= (select_keyword + delimitedList(column) +
- from_keyword + delimitedList(table) +
- Optional(join_expr) +
- Optional(where_keyword + where_expr) +
- Optional(group_by_keyword + group_by_expr) +
- Optional(order_by_keyword + order_by_expr))
+ select_stmt <<= (
+ select_keyword
+ + delimitedList(column)
+ + from_keyword
+ + delimitedList(table)
+ + Optional(join_expr)
+ + Optional(where_keyword + where_expr)
+ + Optional(group_by_keyword + group_by_expr)
+ + Optional(order_by_keyword + order_by_expr)
+ )
join_expr <<= join_keyword + table + on_keyword + column + Literal("=") + column
- where_expr <<= column + Literal("=") + Word(alphanums) + \
- Optional(and_keyword + where_expr) | \
- column + Literal(">") + Word(alphanums) + \
- Optional(and_keyword + where_expr) | \
- column + Literal("<") + Word(alphanums) + \
- Optional(and_keyword + where_expr)
+ where_expr <<= (
+ column + Literal("=") + Word(alphanums) + Optional(and_keyword + where_expr)
+ | column + Literal(">") + Word(alphanums) + Optional(and_keyword + where_expr)
+ | column + Literal("<") + Word(alphanums) + Optional(and_keyword + where_expr)
+ )
group_by_expr <<= delimitedList(column)
@@ -88,7 +106,6 @@ def parse_sql(sql):
return parsed_result.asList()
-
def add_quotes(sql, column_names=[]):
sql = sql.replace("`", "")
sql = sql.replace("'", "")
@@ -108,6 +125,7 @@ def deep_quotes(token, column_names=[]):
new_value = token.value.replace("`", "").replace("'", "")
token.value = f'"{new_value}"'
+
def get_select_clause(sql):
    parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement block
@@ -123,6 +141,7 @@ def get_select_clause(sql):
select_tokens.append(token)
return "".join(str(token) for token in select_tokens)
+
def parse_select_fields(sql):
    parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement block
fields = []
@@ -139,12 +158,14 @@ def parse_select_fields(sql):
return fields
+
def add_quotes_to_chinese_columns(sql, column_names=[]):
parsed = sqlparse.parse(sql)
for stmt in parsed:
process_statement(stmt, column_names)
return str(parsed[0])
+
def process_statement(statement, column_names=[]):
if isinstance(statement, sqlparse.sql.IdentifierList):
for identifier in statement.get_identifiers():
@@ -155,22 +176,23 @@ def process_statement(statement, column_names=[]):
for item in statement.tokens:
process_statement(item)
+
def process_identifier(identifier, column_names=[]):
# if identifier.has_alias():
# alias = identifier.get_alias()
# identifier.tokens[-1].value = '[' + alias + ']'
- if hasattr(identifier, 'tokens') and identifier.value in column_names:
- if is_chinese(identifier.value):
+ if hasattr(identifier, "tokens") and identifier.value in column_names:
+ if is_chinese(identifier.value):
new_value = get_new_value(identifier.value)
identifier.value = new_value
identifier.normalized = new_value
identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
else:
- if hasattr(identifier, 'tokens'):
+ if hasattr(identifier, "tokens"):
for token in identifier.tokens:
if isinstance(token, sqlparse.sql.Function):
process_function(token)
- elif token.ttype in sqlparse.tokens.Name :
+ elif token.ttype in sqlparse.tokens.Name:
new_value = get_new_value(token.value)
token.value = new_value
token.normalized = new_value
@@ -179,9 +201,12 @@ def process_identifier(identifier, column_names=[]):
token.value = new_value
token.normalized = new_value
token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
+
+
def get_new_value(value):
return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """
+
def process_function(function):
function_params = list(function.get_parameters())
# for param in function_params:
@@ -191,15 +216,18 @@ def process_function(function):
if isinstance(param, sqlparse.sql.Identifier):
            # Decide whether the field value needs replacing
# if is_chinese(param.value):
- # 替换字段值
+            # Replace the field value
new_value = get_new_value(param.value)
# new_parameter = sqlparse.sql.Identifier(f'[{param.value}]')
- function_params[i].tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
+ function_params[i].tokens = [
+ sqlparse.sql.Token(sqlparse.tokens.Name, new_value)
+ ]
print(str(function))
+
def is_chinese(text):
for char in text:
- if '一' <= char <= '鿿':
+ if "一" <= char <= "鿿":
return True
return False
@@ -240,7 +268,7 @@ class ExcelReader:
df_tmp = pd.read_csv(file_path, encoding=encoding)
self.df = pd.read_csv(
file_path,
- encoding = encoding,
+ encoding=encoding,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
else:
@@ -280,7 +308,6 @@ class ExcelReader:
logging.error("excel sql run error!", e)
raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
-
def get_df_by_sql_ex(self, sql):
colunms, values = self.run(sql)
return pd.DataFrame(values, columns=colunms)
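A hedged usage sketch for the quoting helpers above: LLM-generated SQL over an uploaded spreadsheet often references Chinese column names unquoted, which DuckDB rejects, so `add_quotes_to_chinese_columns` rewrites them as quoted identifiers. The input and the expected shape below are illustrative; the exact output depends on the token stream `sqlparse` produces:

```python
from pilot.scene.chat_data.chat_excel.excel_reader import add_quotes_to_chinese_columns

sql = "SELECT 销售额, 日期 FROM excel_data WHERE 销售额 > 100"
quoted = add_quotes_to_chinese_columns(sql, column_names=["销售额", "日期"])
print(quoted)  # expected shape: SELECT "销售额", "日期" FROM excel_data ...
```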
diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py
index 2e3b759ee..e4c5175a6 100644
--- a/pilot/scene/chat_execution/chat.py
+++ b/pilot/scene/chat_execution/chat.py
@@ -16,7 +16,7 @@ class ChatWithPlugin(BaseChat):
select_plugin: str = None
def __init__(self, chat_param: Dict):
- self.plugin_selector = chat_param["select_param"]
+ self.plugin_selector = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatExecution
super().__init__(chat_param=chat_param)
self.plugins_prompt_generator = PluginPromptGenerator()
diff --git a/pilot/server/base.py b/pilot/server/base.py
index 3cc9b16df..42b1f6a33 100644
--- a/pilot/server/base.py
+++ b/pilot/server/base.py
@@ -15,7 +15,6 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
-
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
os._exit(0)
@@ -32,7 +31,6 @@ def async_db_summary(system_app: SystemApp):
def server_init(args, system_app: SystemApp):
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
-
# logger.info(f"args: {args}")
# init config
@@ -44,8 +42,6 @@ def server_init(args, system_app: SystemApp):
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
-
-
# Loader plugins and commands
command_categories = [
"pilot.base_modules.agent.commands.built_in.audio_text",
@@ -126,4 +122,3 @@ class WebWerverParameters(BaseParameters):
},
)
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
-
diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py
index 06f2c4a16..d700de94c 100644
--- a/pilot/server/component_configs.py
+++ b/pilot/server/component_configs.py
@@ -29,6 +29,7 @@ def initialize_components(
system_app.register_instance(controller)
from pilot.base_modules.agent.controller import module_agent
+
system_app.register_instance(module_agent)
_initialize_embedding_model(
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 71ef9f8d8..635ae99cf 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -31,7 +31,9 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
-from pilot.base_modules.agent.commands.disply_type.show_chart_gen import static_message_img_path
+from pilot.base_modules.agent.commands.disply_type.show_chart_gen import (
+ static_message_img_path,
+)
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import (
setup_logging,
@@ -56,6 +58,8 @@ def swagger_monkey_patch(*args, **kwargs):
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
)
+
+
app = FastAPI()
applications.get_swagger_ui_html = swagger_monkey_patch
@@ -73,14 +77,14 @@ app.add_middleware(
)
-app.include_router(api_v1, prefix="/api", tags=["Chat"])
-app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
+app.include_router(api_v1, prefix="/api", tags=["Chat"])
+app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
-app.include_router(knowledge_router, tags=["Knowledge"])
-app.include_router(prompt_router, tags=["Prompt"])
+app.include_router(knowledge_router, tags=["Knowledge"])
+app.include_router(prompt_router, tags=["Prompt"])
def mount_static_files(app):
@@ -98,6 +102,7 @@ def mount_static_files(app):
app.add_exception_handler(RequestValidationError, validation_exception_handler)
+
def _get_webserver_params(args: List[str] = None):
from pilot.utils.parameter_utils import EnvArgumentParser
@@ -106,6 +111,7 @@ def _get_webserver_params(args: List[str] = None):
)
return WebWerverParameters(**vars(parser.parse_args(args=args)))
+
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
"""Initialize app
    If you use gunicorn as a process manager, initialize_app can be invoked in the `on_starting` hook.
diff --git a/pilot/vector_store/__init__.py b/pilot/vector_store/__init__.py
index daca3b81c..ff7e70dbc 100644
--- a/pilot/vector_store/__init__.py
+++ b/pilot/vector_store/__init__.py
@@ -1,21 +1,30 @@
from typing import Any
+
def _import_pgvector() -> Any:
- from pilot.vector_store.pgvector_store import PGVectorStore
+ from pilot.vector_store.pgvector_store import PGVectorStore
+
return PGVectorStore
+
def _import_milvus() -> Any:
from pilot.vector_store.milvus_store import MilvusStore
+
return MilvusStore
+
def _import_chroma() -> Any:
from pilot.vector_store.chroma_store import ChromaStore
+
return ChromaStore
+
def _import_weaviate() -> Any:
from pilot.vector_store.weaviate_store import WeaviateStore
+
return WeaviateStore
+
def __getattr__(name: str) -> Any:
if name == "Chroma":
return _import_chroma()
@@ -28,9 +37,5 @@ def __getattr__(name: str) -> Any:
else:
raise AttributeError(f"Could not find: {name}")
-__all__ = [
- "Chroma",
- "Milvus",
- "Weaviate",
- "PGVector"
-]
\ No newline at end of file
+
+__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
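The rewritten `pilot/vector_store/__init__.py` leans on PEP 562 module-level `__getattr__`: each backend is imported only when its name is first accessed, so `import pilot.vector_store` stays cheap even when, say, `pymilvus` is not installed. A minimal standalone module sketching the same pattern (the `JsonStore` backend is hypothetical):

```python
# lazy_stores.py -- module-level __getattr__ is a PEP 562 feature.
from typing import Any


def _import_json_store() -> Any:
    # The (potentially expensive) import runs only on first access.
    from json import JSONDecoder as JsonStore  # hypothetical stand-in backend

    return JsonStore


def __getattr__(name: str) -> Any:
    if name == "JsonStore":
        return _import_json_store()
    raise AttributeError(f"Could not find: {name}")


__all__ = ["JsonStore"]
```

Accessing `lazy_stores.JsonStore` triggers the deferred import; any other attribute raises `AttributeError`, just as in the module above.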
diff --git a/pilot/vector_store/base.py b/pilot/vector_store/base.py
index 7eac8aa25..eb746c7a8 100644
--- a/pilot/vector_store/base.py
+++ b/pilot/vector_store/base.py
@@ -17,7 +17,7 @@ class VectorStoreBase(ABC):
@abstractmethod
def vector_name_exists(self) -> bool:
"""is vector store name exist."""
- return False
+ return False
@abstractmethod
def delete_by_ids(self, ids):
diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py
index efc248aba..fd2198c0f 100644
--- a/pilot/vector_store/connector.py
+++ b/pilot/vector_store/connector.py
@@ -3,6 +3,7 @@ from pilot.vector_store.base import VectorStoreBase
connector = {}
+
class VectorStoreConnector:
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
@@ -16,16 +17,15 @@ class VectorStoreConnector:
"""initialize vector store connector."""
self.ctx = ctx
self._register()
-
+
if self._match(vector_store_type):
self.connector_class = connector.get(vector_store_type)
else:
raise Exception(f"Vector Type Not support. {0}", vector_store_type)
-
- print(self.connector_class)
+
+ print(self.connector_class)
self.client = self.connector_class(ctx)
-
def load_document(self, docs):
"""load document in vector database."""
return self.client.load_document(docs)
@@ -51,9 +51,9 @@ class VectorStoreConnector:
return True
else:
return False
-
+
def _register(self):
for cls in vector_store.__all__:
if issubclass(getattr(vector_store, cls), VectorStoreBase):
_k, _v = cls, getattr(vector_store, cls)
- connector.update({_k: _v})
\ No newline at end of file
+ connector.update({_k: _v})
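`_register` builds a plain name-to-class registry by walking `vector_store.__all__` and keeping every subclass of `VectorStoreBase`; that mapping is what lets the constructor resolve a backend from the `vector_store_type` string. The same pattern in isolation, with stand-in classes rather than the real backends:

```python
from abc import ABC


class VectorStoreBase(ABC):  # stand-in for pilot.vector_store.base
    pass


class Chroma(VectorStoreBase):
    pass


class Milvus(VectorStoreBase):
    pass


__all__ = ["Chroma", "Milvus"]

connector: dict = {}
for _name in __all__:
    _cls = globals()[_name]
    if issubclass(_cls, VectorStoreBase):
        connector[_name] = _cls

# Lookup by the vector_store_type string, as VectorStoreConnector does.
assert connector["Chroma"] is Chroma
```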
diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py
index 5deca8b47..ee304fe25 100644
--- a/pilot/vector_store/milvus_store.py
+++ b/pilot/vector_store/milvus_store.py
@@ -3,7 +3,6 @@ import logging
import os
from typing import Any, Iterable, List, Optional, Tuple
-from pymilvus import Collection, DataType, connections, utility
from pilot.vector_store.base import VectorStoreBase
@@ -14,6 +13,8 @@ class MilvusStore(VectorStoreBase):
"""Milvus database"""
def __init__(self, ctx: {}) -> None:
+ from pymilvus import Collection, DataType, connections, utility
+
"""init a milvus storage connection.
Args:
@@ -85,6 +86,7 @@ class MilvusStore(VectorStoreBase):
DataType,
FieldSchema,
connections,
+ utility,
)
from pymilvus.orm.types import infer_dtype_bydata
except ImportError:
@@ -260,6 +262,8 @@ class MilvusStore(VectorStoreBase):
return doc_ids
def similar_search(self, text, topk) -> None:
+ from pymilvus import Collection, DataType
+
"""similar_search in vector database."""
self.col = Collection(self.collection_name)
schema = self.col.schema
@@ -324,16 +328,22 @@ class MilvusStore(VectorStoreBase):
return data[0], ret
def vector_name_exists(self):
+ from pymilvus import utility
+
"""is vector store name exist."""
return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name):
+ from pymilvus import utility
+
"""milvus delete collection name"""
logger.info(f"milvus vector_name:{vector_name} begin delete...")
utility.drop_collection(vector_name)
return True
def delete_by_ids(self, ids):
+ from pymilvus import Collection
+
self.col = Collection(self.collection_name)
"""milvus delete vectors by ids"""
logger.info(f"begin delete milvus ids...")
@@ -342,6 +352,3 @@ class MilvusStore(VectorStoreBase):
delet_expr = f"{self.primary_field} in {doc_ids}"
self.col.delete(delet_expr)
return True
-
- def close(self):
- connections.disconnect()
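The `MilvusStore` changes move every `pymilvus` import inside the methods that need it, which makes `pymilvus` an optional dependency: the package imports cleanly without it and fails only when Milvus functionality is actually invoked. A compact sketch of that deferred-import pattern (the class is illustrative, not the project's):

```python
class OptionalMilvusFeature:
    """Illustrative class showing method-local optional imports."""

    def similar_search(self, collection_name: str):
        try:
            from pymilvus import Collection  # deferred optional dependency
        except ImportError as exc:
            raise ImportError(
                "pymilvus is required for Milvus search; "
                "install it with `pip install pymilvus`."
            ) from exc
        # Real code would instantiate and query the collection here.
        return Collection
```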
diff --git a/pilot/vector_store/pgvector_store.py b/pilot/vector_store/pgvector_store.py
index 98ce4a027..5f6661871 100644
--- a/pilot/vector_store/pgvector_store.py
+++ b/pilot/vector_store/pgvector_store.py
@@ -7,32 +7,32 @@ logger = logging.getLogger(__name__)
CFG = Config()
+
class PGVectorStore(VectorStoreBase):
- """`Postgres.PGVector` vector store.
-
+ """`Postgres.PGVector` vector store.
+
To use this, you should have the ``pgvector`` python package installed.
"""
def __init__(self, ctx: dict) -> None:
"""init pgvector storage"""
-
+
from langchain.vectorstores import PGVector
-
+
self.ctx = ctx
self.connection_string = ctx.get("connection_string", None)
self.embeddings = ctx.get("embeddings", None)
self.collection_name = ctx.get("vector_store_name", None)
-
+
self.vector_store_client = PGVector(
embedding_function=self.embeddings,
collection_name=self.collection_name,
- connection_string=self.connection_string
+ connection_string=self.connection_string,
)
-
- def similar_search(self, text, topk, **kwargs: Any) -> None:
- return self.vector_store_client.similarity_search(text, topk)
-
+ def similar_search(self, text, topk, **kwargs: Any) -> None:
+ return self.vector_store_client.similarity_search(text, topk)
+
def vector_name_exists(self):
try:
self.vector_store_client.create_collection()
@@ -40,14 +40,12 @@ class PGVectorStore(VectorStoreBase):
except Exception as e:
logger.error("vector_name_exists error", e.message)
return False
-
+
def load_document(self, documents) -> None:
return self.vector_store_client.from_documents(documents)
-
def delete_vector_name(self, vector_name):
- return self.vector_store_client.delete_collection()
+ return self.vector_store_client.delete_collection()
-
def delete_by_ids(self, ids):
- return self.vector_store_client.delete(ids)
\ No newline at end of file
+ return self.vector_store_client.delete(ids)
diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py
index 795cf21f9..14b90eb54 100644
--- a/pilot/vector_store/weaviate_store.py
+++ b/pilot/vector_store/weaviate_store.py
@@ -1,10 +1,6 @@
import os
-import json
import logging
-import weaviate
from langchain.schema import Document
-from langchain.vectorstores import Weaviate
-from weaviate.exceptions import WeaviateBaseError
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
@@ -72,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
if self.vector_store_client.schema.get(self.vector_name):
return True
return False
- except WeaviateBaseError as e:
+ except Exception as e:
logger.error("vector_name_exists error", e.message)
return False
diff --git a/setup.py b/setup.py
index 1876669cf..992e1cd10 100644
--- a/setup.py
+++ b/setup.py
@@ -316,6 +316,8 @@ def core_requires():
"jsonschema",
# TODO move transformers to default
"transformers>=4.31.0",
+ "GitPython",
+ "alembic",
]
@@ -404,11 +406,13 @@ def vllm_requires():
"""
setup_spec.extras["vllm"] = ["vllm"]
+
# def chat_scene():
# setup_spec.extras["chat"] = [
# ""
# ]
+
def default_requires():
"""
pip install "db-gpt[default]"
@@ -420,7 +424,7 @@ def default_requires():
"protobuf==3.20.3",
"zhipuai",
"dashscope",
- "chardet"
+ "chardet",
]
setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
diff --git a/tests/unit_tests/vector_store/test_pgvector.py b/tests/unit_tests/vector_store/test_pgvector.py
index c96643683..59319a124 100644
--- a/tests/unit_tests/vector_store/test_pgvector.py
+++ b/tests/unit_tests/vector_store/test_pgvector.py
@@ -3,8 +3,9 @@ import pytest
from pilot import vector_store
from pilot.vector_store.base import VectorStoreBase
-def test_vetorestore_imports() -> None:
- """ Simple test to make sure all things can be imported."""
- for cls in vector_store.__all__:
+def test_vetorestore_imports() -> None:
+ """Simple test to make sure all things can be imported."""
+
+ for cls in vector_store.__all__:
assert issubclass(getattr(vector_store, cls), VectorStoreBase)