From e440daeaf231fd6887a69851607a3d7c556ba79e Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 17 Oct 2023 00:03:24 +0800 Subject: [PATCH 1/6] fix:weaviate json error --- pilot/vector_store/weaviate_store.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py index 795cf21f9..14b90eb54 100644 --- a/pilot/vector_store/weaviate_store.py +++ b/pilot/vector_store/weaviate_store.py @@ -1,10 +1,6 @@ import os -import json import logging -import weaviate from langchain.schema import Document -from langchain.vectorstores import Weaviate -from weaviate.exceptions import WeaviateBaseError from pilot.configs.config import Config from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH @@ -72,7 +68,7 @@ class WeaviateStore(VectorStoreBase): if self.vector_store_client.schema.get(self.vector_name): return True return False - except WeaviateBaseError as e: + except Exception as e: logger.error("vector_name_exists error", e.message) return False From 02ab630a76c2863f0f55400e3f08c8c9dfd7a1ec Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 17 Oct 2023 00:13:06 +0800 Subject: [PATCH 2/6] doc:readme --- README.zh.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/README.zh.md b/README.zh.md index 3cfaa6ed5..f1a73f063 100644 --- a/README.zh.md +++ b/README.zh.md @@ -258,7 +258,17 @@ The MIT License (MIT) - [ ] Code - [ ] Images - [x] RAG -- [ ] KnownledgeGraph +- [ ] Graph Database + - [ ] Neo4j Graph + - [ ] Nebula Graph +- [x] Multi Vector Database + - [x] Chroma + - [x] Milvus + - [x] Weaviate + - [x] PGVector + - [ ] Elasticsearch + - [ ] ClickHouse + - [ ] Faiss ### 多数据源支持 @@ -286,6 +296,7 @@ The MIT License (MIT) ### 多模型管理与推理优化 - [x] [集群部署](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) - [x] [fastchat支持](https://github.com/lm-sys/FastChat) +- [x] [fastchat支持](https://github.com/lm-sys/FastChat) - [x] [vLLM 支持](https://db-gpt.readthedocs.io/en/latest/getting_started/install/llm/vllm/vllm.html) ### Agents与插件市场 From c5b8aeedf5737276aeff880d257b83e22ca4dda4 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 17 Oct 2023 00:13:18 +0800 Subject: [PATCH 3/6] doc:readme --- README.md | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 88b488aea..06d9e8022 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,7 @@ The core capabilities mainly consist of the following parts: ![macOS](https://img.shields.io/badge/mac%20os-000000?style=for-the-badge&logo=macos&logoColor=F0F0F0) ![Windows](https://img.shields.io/badge/Windows-0078D6?style=for-the-badge&logo=windows&logoColor=white) -[**Quickstart**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html) +[**Installation && Usage Tutorial**](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html) ### Language Switching In the .env configuration file, modify the LANGUAGE parameter to switch to different languages. The default is English (Chinese: zh, English: en, other languages to be added later). @@ -214,8 +214,17 @@ The core capabilities mainly consist of the following parts: - [ ] Images - [x] RAG -- [ ] KnownledgeGraph - +- [ ] Graph Database + - [ ] Neo4j Graph + - [ ] Nebula Graph +- [x] Multi Vector Database + - [x] Chroma + - [x] Milvus + - [x] Weaviate + - [x] PGVector + - [ ] Elasticsearch + - [ ] ClickHouse + - [ ] Faiss ### Multi Datasource Support - Multi Datasource Support @@ -239,9 +248,9 @@ The core capabilities mainly consist of the following parts: - [ ] StarRocks ### Multi-Models And vLLM -- [x] [cluster deployment](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) -- [x] [fastchat support](https://github.com/lm-sys/FastChat) -- [x] [vLLM support](https://db-gpt.readthedocs.io/en/latest/getting_started/install/llm/vllm/vllm.html) +- [x] [Cluster Deployment](https://db-gpt.readthedocs.io/en/latest/getting_started/install/cluster/vms/index.html) +- [x] [Fastchat Support](https://github.com/lm-sys/FastChat) +- [x] [vLLM Support](https://db-gpt.readthedocs.io/en/latest/getting_started/install/llm/vllm/vllm.html) ### Agents market and Plugins - [x] multi-agents framework From 496696537d2a73d4076ca4560a434a10bca532ce Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 17 Oct 2023 11:52:45 +0800 Subject: [PATCH 4/6] fix:vectordb lazy load --- pilot/vector_store/__init__.py | 19 ++++++++----- pilot/vector_store/base.py | 2 +- pilot/vector_store/connector.py | 12 ++++---- pilot/vector_store/milvus_store.py | 15 +++++++--- pilot/vector_store/pgvector_store.py | 28 +++++++++---------- .../unit_tests/vector_store/test_pgvector.py | 7 +++-- 6 files changed, 47 insertions(+), 36 deletions(-) diff --git a/pilot/vector_store/__init__.py b/pilot/vector_store/__init__.py index daca3b81c..ff7e70dbc 100644 --- a/pilot/vector_store/__init__.py +++ b/pilot/vector_store/__init__.py @@ -1,21 +1,30 @@ from typing import Any + def _import_pgvector() -> Any: - from pilot.vector_store.pgvector_store import PGVectorStore + from pilot.vector_store.pgvector_store import PGVectorStore + return PGVectorStore + def _import_milvus() -> Any: from pilot.vector_store.milvus_store import MilvusStore + return MilvusStore + def _import_chroma() -> Any: from pilot.vector_store.chroma_store import ChromaStore + return ChromaStore + def _import_weaviate() -> Any: from pilot.vector_store.weaviate_store import WeaviateStore + return WeaviateStore + def __getattr__(name: str) -> Any: if name == "Chroma": return _import_chroma() @@ -28,9 +37,5 @@ def __getattr__(name: str) -> Any: else: raise AttributeError(f"Could not find: {name}") -__all__ = [ - "Chroma", - "Milvus", - "Weaviate", - "PGVector" -] \ No newline at end of file + +__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"] diff --git a/pilot/vector_store/base.py b/pilot/vector_store/base.py index 7eac8aa25..eb746c7a8 100644 --- a/pilot/vector_store/base.py +++ b/pilot/vector_store/base.py @@ -17,7 +17,7 @@ class VectorStoreBase(ABC): @abstractmethod def vector_name_exists(self) -> bool: """is vector store name exist.""" - return False + return False @abstractmethod def delete_by_ids(self, ids): diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index efc248aba..fd2198c0f 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -3,6 +3,7 @@ from pilot.vector_store.base import VectorStoreBase connector = {} + class VectorStoreConnector: """VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1. 1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate) @@ -16,16 +17,15 @@ class VectorStoreConnector: """initialize vector store connector.""" self.ctx = ctx self._register() - + if self._match(vector_store_type): self.connector_class = connector.get(vector_store_type) else: raise Exception(f"Vector Type Not support. {0}", vector_store_type) - - print(self.connector_class) + + print(self.connector_class) self.client = self.connector_class(ctx) - def load_document(self, docs): """load document in vector database.""" return self.client.load_document(docs) @@ -51,9 +51,9 @@ class VectorStoreConnector: return True else: return False - + def _register(self): for cls in vector_store.__all__: if issubclass(getattr(vector_store, cls), VectorStoreBase): _k, _v = cls, getattr(vector_store, cls) - connector.update({_k: _v}) \ No newline at end of file + connector.update({_k: _v}) diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 5deca8b47..ee304fe25 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -3,7 +3,6 @@ import logging import os from typing import Any, Iterable, List, Optional, Tuple -from pymilvus import Collection, DataType, connections, utility from pilot.vector_store.base import VectorStoreBase @@ -14,6 +13,8 @@ class MilvusStore(VectorStoreBase): """Milvus database""" def __init__(self, ctx: {}) -> None: + from pymilvus import Collection, DataType, connections, utility + """init a milvus storage connection. Args: @@ -85,6 +86,7 @@ class MilvusStore(VectorStoreBase): DataType, FieldSchema, connections, + utility, ) from pymilvus.orm.types import infer_dtype_bydata except ImportError: @@ -260,6 +262,8 @@ class MilvusStore(VectorStoreBase): return doc_ids def similar_search(self, text, topk) -> None: + from pymilvus import Collection, DataType + """similar_search in vector database.""" self.col = Collection(self.collection_name) schema = self.col.schema @@ -324,16 +328,22 @@ class MilvusStore(VectorStoreBase): return data[0], ret def vector_name_exists(self): + from pymilvus import utility + """is vector store name exist.""" return utility.has_collection(self.collection_name) def delete_vector_name(self, vector_name): + from pymilvus import utility + """milvus delete collection name""" logger.info(f"milvus vector_name:{vector_name} begin delete...") utility.drop_collection(vector_name) return True def delete_by_ids(self, ids): + from pymilvus import Collection + self.col = Collection(self.collection_name) """milvus delete vectors by ids""" logger.info(f"begin delete milvus ids...") @@ -342,6 +352,3 @@ class MilvusStore(VectorStoreBase): delet_expr = f"{self.primary_field} in {doc_ids}" self.col.delete(delet_expr) return True - - def close(self): - connections.disconnect() diff --git a/pilot/vector_store/pgvector_store.py b/pilot/vector_store/pgvector_store.py index 98ce4a027..5f6661871 100644 --- a/pilot/vector_store/pgvector_store.py +++ b/pilot/vector_store/pgvector_store.py @@ -7,32 +7,32 @@ logger = logging.getLogger(__name__) CFG = Config() + class PGVectorStore(VectorStoreBase): - """`Postgres.PGVector` vector store. - + """`Postgres.PGVector` vector store. + To use this, you should have the ``pgvector`` python package installed. """ def __init__(self, ctx: dict) -> None: """init pgvector storage""" - + from langchain.vectorstores import PGVector - + self.ctx = ctx self.connection_string = ctx.get("connection_string", None) self.embeddings = ctx.get("embeddings", None) self.collection_name = ctx.get("vector_store_name", None) - + self.vector_store_client = PGVector( embedding_function=self.embeddings, collection_name=self.collection_name, - connection_string=self.connection_string + connection_string=self.connection_string, ) - - def similar_search(self, text, topk, **kwargs: Any) -> None: - return self.vector_store_client.similarity_search(text, topk) - + def similar_search(self, text, topk, **kwargs: Any) -> None: + return self.vector_store_client.similarity_search(text, topk) + def vector_name_exists(self): try: self.vector_store_client.create_collection() @@ -40,14 +40,12 @@ class PGVectorStore(VectorStoreBase): except Exception as e: logger.error("vector_name_exists error", e.message) return False - + def load_document(self, documents) -> None: return self.vector_store_client.from_documents(documents) - def delete_vector_name(self, vector_name): - return self.vector_store_client.delete_collection() + return self.vector_store_client.delete_collection() - def delete_by_ids(self, ids): - return self.vector_store_client.delete(ids) \ No newline at end of file + return self.vector_store_client.delete(ids) diff --git a/tests/unit_tests/vector_store/test_pgvector.py b/tests/unit_tests/vector_store/test_pgvector.py index c96643683..59319a124 100644 --- a/tests/unit_tests/vector_store/test_pgvector.py +++ b/tests/unit_tests/vector_store/test_pgvector.py @@ -3,8 +3,9 @@ import pytest from pilot import vector_store from pilot.vector_store.base import VectorStoreBase -def test_vetorestore_imports() -> None: - """ Simple test to make sure all things can be imported.""" - for cls in vector_store.__all__: +def test_vetorestore_imports() -> None: + """Simple test to make sure all things can be imported.""" + + for cls in vector_store.__all__: assert issubclass(getattr(vector_store, cls), VectorStoreBase) From da87e401634d87fed32a8224ba5d1d48a34ed42b Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 17 Oct 2023 13:52:13 +0800 Subject: [PATCH 5/6] chore:add GitPython && alembic requirement --- setup.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3785a277d..32ba11d73 100644 --- a/setup.py +++ b/setup.py @@ -315,6 +315,8 @@ def core_requires(): "jsonschema", # TODO move transformers to default "transformers>=4.31.0", + "GitPython", + "alembic", ] @@ -403,11 +405,13 @@ def vllm_requires(): """ setup_spec.extras["vllm"] = ["vllm"] + # def chat_scene(): # setup_spec.extras["chat"] = [ # "" # ] + def default_requires(): """ pip install "db-gpt[default]" @@ -419,7 +423,7 @@ def default_requires(): "protobuf==3.20.3", "zhipuai", "dashscope", - "chardet" + "chardet", ] setup_spec.extras["default"] += setup_spec.extras["framework"] setup_spec.extras["default"] += setup_spec.extras["knowledge"] From f65ca37a0226ae7b941b811f1f9ae4bc34117214 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 17 Oct 2023 13:55:19 +0800 Subject: [PATCH 6/6] style:fmt --- examples/proxy_example.py | 40 +++--- pilot/base_modules/agent/__init__.py | 4 +- pilot/base_modules/agent/commands/command.py | 8 +- .../agent/commands/command_mange.py | 129 ++++++++++++------ .../commands/disply_type/show_chart_gen.py | 22 ++- .../commands/disply_type/show_table_gen.py | 2 + .../commands/disply_type/show_text_gen.py | 3 +- .../base_modules/agent/commands/generator.py | 2 - pilot/base_modules/agent/common/schema.py | 9 +- pilot/base_modules/agent/controller.py | 46 ++++--- pilot/base_modules/agent/db/my_plugin_db.py | 85 ++++-------- pilot/base_modules/agent/db/plugin_hub_db.py | 48 +++---- pilot/base_modules/agent/hub/agent_hub.py | 51 ++++--- pilot/base_modules/agent/model.py | 38 ++++-- pilot/base_modules/agent/plugins_util.py | 37 +++-- pilot/base_modules/mange_base_api.py | 3 +- pilot/base_modules/meta_data/base_dao.py | 9 +- pilot/base_modules/meta_data/meta_data.py | 58 ++++---- pilot/base_modules/module_factory.py | 1 - pilot/common/string_utils.py | 32 +++-- pilot/configs/config.py | 17 +-- pilot/connections/__init__.py | 2 +- .../connections/manages/connect_config_db.py | 22 ++- pilot/memory/__init__.py | 2 +- pilot/memory/chat_history/base.py | 12 +- .../chat_history/chat_hisotry_factory.py | 9 +- pilot/memory/chat_history/chat_history_db.py | 41 +++--- .../chat_history/store_type/duckdb_history.py | 1 - .../chat_history/store_type/file_history.py | 4 +- .../chat_history/store_type/mem_history.py | 2 +- .../store_type/meta_db_history.py | 33 ++--- pilot/meta_data/alembic/env.py | 10 +- pilot/model/cluster/controller/controller.py | 4 +- pilot/model/proxy/llms/baichuan.py | 22 +-- pilot/model/proxy/llms/spark.py | 79 +++++------ pilot/model/proxy/llms/tongyi.py | 9 +- pilot/model/proxy/llms/wenxin.py | 42 +++--- pilot/model/proxy/llms/zhipu.py | 6 +- pilot/openapi/api_v1/api_v1.py | 1 - pilot/scene/base_chat.py | 2 +- pilot/scene/chat_agent/chat.py | 14 +- pilot/scene/chat_agent/prompt.py | 4 +- .../chat_excel/excel_analyze/chat.py | 4 +- .../chat_excel/excel_learning/prompt.py | 2 +- .../chat_data/chat_excel/excel_reader.py | 83 +++++++---- pilot/scene/chat_execution/chat.py | 2 +- pilot/server/base.py | 5 - pilot/server/component_configs.py | 1 + pilot/server/dbgpt_server.py | 16 ++- 49 files changed, 582 insertions(+), 496 deletions(-) diff --git a/examples/proxy_example.py b/examples/proxy_example.py index 5d2f8e5db..a3d2f3bc4 100644 --- a/examples/proxy_example.py +++ b/examples/proxy_example.py @@ -7,55 +7,61 @@ import hashlib from http import HTTPStatus from dashscope import Generation + def call_with_messages(): - messages = [{'role': 'system', 'content': '你是生活助手机器人。'}, - {'role': 'user', 'content': '如何做西红柿鸡蛋?'}] + messages = [ + {"role": "system", "content": "你是生活助手机器人。"}, + {"role": "user", "content": "如何做西红柿鸡蛋?"}, + ] gen = Generation() response = gen.call( Generation.Models.qwen_turbo, messages=messages, stream=True, top_p=0.8, - result_format='message', # set the result to be "message" format. + result_format="message", # set the result to be "message" format. ) for response in response: # The response status_code is HTTPStatus.OK indicate success, # otherwise indicate request is failed, you can get error code # and message from code and message. - if response.status_code == HTTPStatus.OK: - print(response.output) # The output text + if response.status_code == HTTPStatus.OK: + print(response.output) # The output text print(response.usage) # The usage information else: - print(response.code) # The error code. - print(response.message) # The error message. - + print(response.code) # The error code. + print(response.message) # The error message. def build_access_token(api_key: str, secret_key: str) -> str: """ - Generate Access token according AK, SK + Generate Access token according AK, SK """ - + url = "https://aip.baidubce.com/oauth/2.0/token" - params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} + params = { + "grant_type": "client_credentials", + "client_id": api_key, + "client_secret": secret_key, + } res = requests.get(url=url, params=params) - + if res.status_code == 200: return res.json().get("access_token") def _calculate_md5(text: str) -> str: - md5 = hashlib.md5() md5.update(text.encode("utf-8")) encrypted = md5.hexdigest() return encrypted + def baichuan_call(): url = "https://api.baichuan-ai.com/v1/stream/chat" - - -if __name__ == '__main__': - call_with_messages() \ No newline at end of file + + +if __name__ == "__main__": + call_with_messages() diff --git a/pilot/base_modules/agent/__init__.py b/pilot/base_modules/agent/__init__.py index e7bbb7b7b..60f0489da 100644 --- a/pilot/base_modules/agent/__init__.py +++ b/pilot/base_modules/agent/__init__.py @@ -1,4 +1,4 @@ -from .db.my_plugin_db import MyPluginEntity, MyPluginDao +from .db.my_plugin_db import MyPluginEntity, MyPluginDao from .db.plugin_hub_db import PluginHubEntity, PluginHubDao from .commands.command import execute_command, get_command @@ -8,4 +8,4 @@ from .commands.disply_type.show_chart_gen import static_message_img_path from .common.schema import Status, PluginStorageType from .commands.command_mange import ApiCall -from .commands.command import execute_command \ No newline at end of file +from .commands.command import execute_command diff --git a/pilot/base_modules/agent/commands/command.py b/pilot/base_modules/agent/commands/command.py index a3202f67d..bd5806ec0 100644 --- a/pilot/base_modules/agent/commands/command.py +++ b/pilot/base_modules/agent/commands/command.py @@ -9,6 +9,7 @@ from .generator import PluginPromptGenerator from pilot.configs.config import Config + def _resolve_pathlike_command_args(command_args): if "directory" in command_args and command_args["directory"] in {"", "/"}: # todo @@ -64,8 +65,6 @@ def execute_ai_response_json( return result - - def execute_command( command_name: str, arguments, @@ -81,10 +80,8 @@ def execute_command( str: The result of the command """ - cmd = plugin_generator.command_registry.commands.get(command_name) - # If the command is found, call it with the provided arguments if cmd: try: @@ -153,6 +150,3 @@ def get_command(response_json: Dict): # All other errors, return "Error: + error message" except Exception as e: return "Error:", str(e) - - - diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py index 7b8f8cc06..be9e02811 100644 --- a/pilot/base_modules/agent/commands/command_mange.py +++ b/pilot/base_modules/agent/commands/command_mange.py @@ -28,13 +28,13 @@ class Command: """ def __init__( - self, - name: str, - description: str, - method: Callable[..., Any], - signature: str = "", - enabled: bool = True, - disabled_reason: Optional[str] = None, + self, + name: str, + description: str, + method: Callable[..., Any], + signature: str = "", + enabled: bool = True, + disabled_reason: Optional[str] = None, ): self.name = name self.description = description @@ -87,11 +87,12 @@ class CommandRegistry: if hasattr(reloaded_module, "register"): reloaded_module.register(self) - def is_valid_command(self, name:str)-> bool: + def is_valid_command(self, name: str) -> bool: if name not in self.commands: return False else: return True + def get_command(self, name: str) -> Callable[..., Any]: return self.commands[name] @@ -129,23 +130,23 @@ class CommandRegistry: attr = getattr(module, attr_name) # Register decorated functions if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr( - attr, AUTO_GPT_COMMAND_IDENTIFIER + attr, AUTO_GPT_COMMAND_IDENTIFIER ): self.register(attr.command) # Register command classes elif ( - inspect.isclass(attr) and issubclass(attr, Command) and attr != Command + inspect.isclass(attr) and issubclass(attr, Command) and attr != Command ): cmd_instance = attr() self.register(cmd_instance) def command( - name: str, - description: str, - signature: str = "", - enabled: bool = True, - disabled_reason: Optional[str] = None, + name: str, + description: str, + signature: str = "", + enabled: bool = True, + disabled_reason: Optional[str] = None, ) -> Callable[..., Any]: """The command decorator is used to create Command objects from ordinary functions.""" @@ -241,34 +242,60 @@ class ApiCall: else: md_tag_end = "```" for tag in error_md_tags: - all_context = all_context.replace(tag + api_context + md_tag_end, api_context) - all_context = all_context.replace(tag + "\n" +api_context + "\n" + md_tag_end, api_context) - all_context = all_context.replace(tag + " " +api_context + " " + md_tag_end, api_context) + all_context = all_context.replace( + tag + api_context + md_tag_end, api_context + ) + all_context = all_context.replace( + tag + "\n" + api_context + "\n" + md_tag_end, api_context + ) + all_context = all_context.replace( + tag + " " + api_context + " " + md_tag_end, api_context + ) all_context = all_context.replace(tag + api_context, api_context) return all_context def api_view_context(self, all_context: str, display_mode: bool = False): error_mk_tags = ["```", "```python", "```xml"] - call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True) + call_context_map = extract_content_open_ending( + all_context, self.agent_prefix, self.agent_end, True + ) for api_index, api_context in call_context_map.items(): api_status = self.plugin_status_map.get(api_context) if api_status is not None: if display_mode: if api_status.api_result: - all_context = self.__deal_error_md_tags(all_context, api_context) - all_context = all_context.replace(api_context, api_status.api_result) + all_context = self.__deal_error_md_tags( + all_context, api_context + ) + all_context = all_context.replace( + api_context, api_status.api_result + ) else: if api_status.status == Status.FAILED.value: - all_context = self.__deal_error_md_tags(all_context, api_context) - all_context = all_context.replace(api_context, f"""\nERROR!{api_status.err_msg}\n """) + all_context = self.__deal_error_md_tags( + all_context, api_context + ) + all_context = all_context.replace( + api_context, + f"""\nERROR!{api_status.err_msg}\n """, + ) else: cost = (api_status.end_time - self.start_time) / 1000 cost_str = "{:.2f}".format(cost) - all_context = self.__deal_error_md_tags(all_context, api_context) - all_context = all_context.replace(api_context, f'\nWaiting...{cost_str}S\n') + all_context = self.__deal_error_md_tags( + all_context, api_context + ) + all_context = all_context.replace( + api_context, + f'\nWaiting...{cost_str}S\n', + ) else: - all_context = self.__deal_error_md_tags(all_context, api_context, False) - all_context = all_context.replace(api_context, self.to_view_text(api_status)) + all_context = self.__deal_error_md_tags( + all_context, api_context, False + ) + all_context = all_context.replace( + api_context, self.to_view_text(api_status) + ) else: # not ready api call view change @@ -276,27 +303,34 @@ class ApiCall: cost = (now_time - self.start_time) / 1000 cost_str = "{:.2f}".format(cost) for tag in error_mk_tags: - all_context = all_context.replace(tag + api_context , api_context) - all_context = all_context.replace(api_context, f'\nWaiting...{cost_str}S\n') + all_context = all_context.replace(tag + api_context, api_context) + all_context = all_context.replace( + api_context, + f'\nWaiting...{cost_str}S\n', + ) return all_context def update_from_context(self, all_context): - api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True) + api_context_map = extract_content( + all_context, self.agent_prefix, self.agent_end, True + ) for api_index, api_context in api_context_map.items(): api_context = api_context.replace("\\n", "").replace("\n", "") api_call_element = ET.fromstring(api_context) - api_name = api_call_element.find('name').text - if api_name.find("[")>=0 or api_name.find("]")>=0: + api_name = api_call_element.find("name").text + if api_name.find("[") >= 0 or api_name.find("]") >= 0: api_name = api_name.replace("[", "").replace("]", "") api_args = {} - args_elements = api_call_element.find('args') + args_elements = api_call_element.find("args") for child_element in args_elements.iter(): api_args[child_element.tag] = child_element.text api_status = self.plugin_status_map.get(api_context) if api_status is None: - api_status = PluginStatus(name=api_name, location=[api_index], args=api_args) + api_status = PluginStatus( + name=api_name, location=[api_index], args=api_args + ) self.plugin_status_map[api_context] = api_status else: api_status.location.append(api_index) @@ -304,20 +338,20 @@ class ApiCall: def __to_view_param_str(self, api_status): param = {} if api_status.name: - param['name'] = api_status.name - param['status'] = api_status.status + param["name"] = api_status.name + param["status"] = api_status.status if api_status.logo_url: - param['logo'] = api_status.logo_url + param["logo"] = api_status.logo_url if api_status.err_msg: - param['err_msg'] = api_status.err_msg + param["err_msg"] = api_status.err_msg if api_status.api_result: - param['result'] = api_status.api_result + param["result"] = api_status.api_result return json.dumps(param) def to_view_text(self, api_status: PluginStatus): - api_call_element = ET.Element('dbgpt-view') + api_call_element = ET.Element("dbgpt-view") api_call_element.text = self.__to_view_param_str(api_status) result = ET.tostring(api_call_element, encoding="utf-8") return result.decode("utf-8") @@ -332,7 +366,9 @@ class ApiCall: value.status = Status.RUNNING.value logging.info(f"插件执行:{value.name},{value.args}") try: - value.api_result = execute_command(value.name, value.args, self.plugin_generator) + value.api_result = execute_command( + value.name, value.args, self.plugin_generator + ) value.status = Status.COMPLETED.value except Exception as e: value.status = Status.FAILED.value @@ -350,15 +386,19 @@ class ApiCall: value.status = Status.RUNNING.value logging.info(f"sql展示执行:{value.name},{value.args}") try: - sql = value.args['sql'] + sql = value.args["sql"] if sql: param = { "df": sql_run_func(sql), } if self.display_registry.is_valid_command(value.name): - value.api_result = self.display_registry.call(value.name, **param) + value.api_result = self.display_registry.call( + value.name, **param + ) else: - value.api_result = self.display_registry.call("response_table", **param) + value.api_result = self.display_registry.call( + "response_table", **param + ) value.status = Status.COMPLETED.value except Exception as e: @@ -366,4 +406,3 @@ class ApiCall: value.err_msg = str(e) value.end_time = datetime.now().timestamp() * 1000 return self.api_view_context(llm_text, True) - diff --git a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py index 7807f3818..166992822 100644 --- a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py +++ b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py @@ -15,6 +15,7 @@ from matplotlib.font_manager import FontManager from pilot.common.string_utils import is_scientific_notation import logging + logger = logging.getLogger(__name__) @@ -88,12 +89,13 @@ def zh_font_set(): if len(can_use_fonts) > 0: plt.rcParams["font.sans-serif"] = can_use_fonts + def format_axis(value, pos): # 判断是否为数字 if is_scientific_notation(value): # 判断是否需要进行非科学计数法格式化 - return '{:.2f}'.format(value) + return "{:.2f}".format(value) return value @@ -102,7 +104,7 @@ def format_axis(value, pos): "Line chart display, used to display comparative trend analysis data", '"df":""', ) -def response_line_chart( df: DataFrame) -> str: +def response_line_chart(df: DataFrame) -> str: logger.info(f"response_line_chart") if df.size <= 0: raise ValueError("No Data!") @@ -143,9 +145,15 @@ def response_line_chart( df: DataFrame) -> str: if len(num_colmns) > 0: num_colmns.append(y) df_melted = pd.melt( - df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value" + df, + id_vars=x, + value_vars=num_colmns, + var_name="line", + value_name="Value", + ) + sns.lineplot( + data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2" ) - sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2") else: sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2") @@ -154,7 +162,7 @@ def response_line_chart( df: DataFrame) -> str: chart_name = "line_" + str(uuid.uuid1()) + ".png" chart_path = static_message_img_path + "/" + chart_name - plt.savefig(chart_path, dpi=100, transparent=True) + plt.savefig(chart_path, dpi=100, transparent=True) html_img = f"""""" return html_img @@ -168,7 +176,7 @@ def response_line_chart( df: DataFrame) -> str: "Histogram, suitable for comparative analysis of multiple target values", '"df":""', ) -def response_bar_chart( df: DataFrame) -> str: +def response_bar_chart(df: DataFrame) -> str: logger.info(f"response_bar_chart") if df.size <= 0: raise ValueError("No Data!") @@ -246,7 +254,7 @@ def response_bar_chart( df: DataFrame) -> str: chart_name = "bar_" + str(uuid.uuid1()) + ".png" chart_path = static_message_img_path + "/" + chart_name - plt.savefig(chart_path, dpi=100,transparent=True) + plt.savefig(chart_path, dpi=100, transparent=True) html_img = f"""""" return html_img diff --git a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py index 9afd14ca5..d11a00c7a 100644 --- a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py +++ b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py @@ -3,8 +3,10 @@ from pandas import DataFrame from pilot.base_modules.agent.commands.command_mange import command import logging + logger = logging.getLogger(__name__) + @command( "response_table", "Table display, suitable for display with many display columns or non-numeric columns", diff --git a/pilot/base_modules/agent/commands/disply_type/show_text_gen.py b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py index 75400fd97..b58ca5843 100644 --- a/pilot/base_modules/agent/commands/disply_type/show_text_gen.py +++ b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py @@ -3,6 +3,7 @@ from pandas import DataFrame from pilot.base_modules.agent.commands.command_mange import command import logging + logger = logging.getLogger(__name__) @@ -22,7 +23,7 @@ def response_data_text(df: DataFrame) -> str: html_table = df.to_html(index=False, escape=False, sparsify=False) table_str = "".join(html_table.split()) html = f"""
{table_str}
""" - text_info = html.replace("\n", " ") + text_info = html.replace("\n", " ") elif row_size == 1: row = data[0] for value in row: diff --git a/pilot/base_modules/agent/commands/generator.py b/pilot/base_modules/agent/commands/generator.py index 310551481..d1dded6fe 100644 --- a/pilot/base_modules/agent/commands/generator.py +++ b/pilot/base_modules/agent/commands/generator.py @@ -133,5 +133,3 @@ class PluginPromptGenerator: def generate_commands_string(self) -> str: return f"{self._generate_numbered_list(self.commands, item_type='command')}" - - diff --git a/pilot/base_modules/agent/common/schema.py b/pilot/base_modules/agent/common/schema.py index 5c36e4da6..87ba196b0 100644 --- a/pilot/base_modules/agent/common/schema.py +++ b/pilot/base_modules/agent/common/schema.py @@ -5,13 +5,14 @@ class PluginStorageType(Enum): Git = "git" Oss = "oss" + class Status(Enum): TODO = "todo" - RUNNING = 'running' - FAILED = 'failed' - COMPLETED = 'completed' + RUNNING = "running" + FAILED = "failed" + COMPLETED = "completed" class ApiTagType(Enum): API_VIEW = "dbgpt_view" - API_CALL = "dbgpt_call" \ No newline at end of file + API_CALL = "dbgpt_call" diff --git a/pilot/base_modules/agent/controller.py b/pilot/base_modules/agent/controller.py index 4bca3dee8..47532a66b 100644 --- a/pilot/base_modules/agent/controller.py +++ b/pilot/base_modules/agent/controller.py @@ -16,7 +16,13 @@ from pilot.openapi.api_view_model import ( Result, ) -from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter +from .model import ( + PluginHubParam, + PagenationFilter, + PagenationResult, + PluginHubFilter, + MyPluginFilter, +) from .hub.agent_hub import AgentHub from .db.plugin_hub_db import PluginHubEntity from .plugins_util import scan_plugins @@ -33,17 +39,18 @@ class ModuleAgent(BaseComponent, ABC): name = ComponentType.AGENT_HUB def __init__(self): - #load plugins + # load plugins self.plugins = scan_plugins(PLUGINS_DIR) def init_app(self, system_app: SystemApp): system_app.app.include_router(router, prefix="/api", tags=["Agent"]) - def refresh_plugins(self): self.plugins = scan_plugins(PLUGINS_DIR) - def load_select_plugin(self, generator:PluginPromptGenerator, select_plugins:List[str])->PluginPromptGenerator: + def load_select_plugin( + self, generator: PluginPromptGenerator, select_plugins: List[str] + ) -> PluginPromptGenerator: logger.info(f"load_select_plugin:{select_plugins}") # load select plugin for plugin in self.plugins: @@ -53,6 +60,7 @@ class ModuleAgent(BaseComponent, ABC): generator = plugin.post_prompt(generator) return generator + module_agent = ModuleAgent() @@ -61,25 +69,28 @@ async def agent_hub_update(update_param: PluginHubParam = Body()): logger.info(f"agent_hub_update:{update_param.__dict__}") try: agent_hub = AgentHub(PLUGINS_DIR) - agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization) + agent_hub.refresh_hub_from_git( + update_param.url, update_param.branch, update_param.authorization + ) return Result.succ(None) except Exception as e: logger.error("Agent Hub Update Error!", e) return Result.faild(code="E0020", msg=f"Agent Hub Update Error! {e}") - @router.post("/v1/agent/query", response_model=Result[str]) async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()): logger.info(f"get_agent_list:{filter.__dict__}") agent_hub = AgentHub(PLUGINS_DIR) - filter_enetity:PluginHubEntity = PluginHubEntity() + filter_enetity: PluginHubEntity = PluginHubEntity() if filter.filter: attrs = vars(filter.filter) # 获取原始对象的属性字典 for attr, value in attrs.items(): setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值 - datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size) + datas, total_pages, total_count = agent_hub.hub_dao.list( + filter_enetity, filter.page_index, filter.page_size + ) result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]() result.page_index = filter.page_index result.page_size = filter.page_size @@ -89,11 +100,12 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()): # print(json.dumps(result.to_dic())) return Result.succ(result.to_dic()) + @router.post("/v1/agent/my", response_model=Result[str]) -async def my_agents(user:str= None): +async def my_agents(user: str = None): logger.info(f"my_agents:{user}") agent_hub = AgentHub(PLUGINS_DIR) - agents = agent_hub.get_my_plugin(user) + agents = agent_hub.get_my_plugin(user) agent_dicts = [] for agent in agents: agent_dicts.append(agent.__dict__) @@ -102,7 +114,7 @@ async def my_agents(user:str= None): @router.post("/v1/agent/install", response_model=Result[str]) -async def agent_install(plugin_name:str, user: str = None): +async def agent_install(plugin_name: str, user: str = None): logger.info(f"agent_install:{plugin_name},{user}") try: agent_hub = AgentHub(PLUGINS_DIR) @@ -111,14 +123,13 @@ async def agent_install(plugin_name:str, user: str = None): module_agent.refresh_plugins() return Result.succ(None) - except Exception as e: + except Exception as e: logger.error("Plugin Install Error!", e) return Result.faild(code="E0021", msg=f"Plugin Install Error {e}") - @router.post("/v1/agent/uninstall", response_model=Result[str]) -async def agent_uninstall(plugin_name:str, user: str = None): +async def agent_uninstall(plugin_name: str, user: str = None): logger.info(f"agent_uninstall:{plugin_name},{user}") try: agent_hub = AgentHub(PLUGINS_DIR) @@ -126,19 +137,18 @@ async def agent_uninstall(plugin_name:str, user: str = None): module_agent.refresh_plugins() return Result.succ(None) - except Exception as e: + except Exception as e: logger.error("Plugin Uninstall Error!", e) return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}") @router.post("/v1/personal/agent/upload", response_model=Result[str]) -async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None): +async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None): logger.info(f"personal_agent_upload:{doc_file.filename},{user}") try: agent_hub = AgentHub(PLUGINS_DIR) await agent_hub.upload_my_plugin(doc_file, user) return Result.succ(None) - except Exception as e: + except Exception as e: logger.error("Upload Personal Plugin Error!", e) return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}") - diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py index 8661d6f70..c49734900 100644 --- a/pilot/base_modules/agent/db/my_plugin_db.py +++ b/pilot/base_modules/agent/db/my_plugin_db.py @@ -9,30 +9,33 @@ from pilot.base_modules.meta_data.base_dao import BaseDao from pilot.base_modules.meta_data.meta_data import Base, engine, session - class MyPluginEntity(Base): - __tablename__ = 'my_plugin' + __tablename__ = "my_plugin" id = Column(Integer, primary_key=True, comment="autoincrement id") tenant = Column(String(255), nullable=True, comment="user's tenant") user_code = Column(String(255), nullable=False, comment="user code") user_name = Column(String(255), nullable=True, comment="user name") name = Column(String(255), unique=True, nullable=False, comment="plugin name") - file_name = Column(String(255), nullable=False, comment="plugin package file name") - type = Column(String(255), comment="plugin type") - version = Column(String(255), comment="plugin version") - use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count") - succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count") - created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time") - __table_args__ = ( - UniqueConstraint('user_code','name', name="uk_name"), + file_name = Column(String(255), nullable=False, comment="plugin package file name") + type = Column(String(255), comment="plugin type") + version = Column(String(255), comment="plugin version") + use_count = Column( + Integer, nullable=True, default=0, comment="plugin total use count" ) + succ_count = Column( + Integer, nullable=True, default=0, comment="plugin total success count" + ) + created_at = Column( + DateTime, default=datetime.utcnow, comment="plugin install time" + ) + __table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),) class MyPluginDao(BaseDao[MyPluginEntity]): def __init__(self): super().__init__( - database="dbgpt", orm_base=Base, db_engine =engine , session= session + database="dbgpt", orm_base=Base, db_engine=engine, session=session ) def add(self, engity: MyPluginEntity): @@ -60,87 +63,61 @@ class MyPluginDao(BaseDao[MyPluginEntity]): session.commit() return updated.id - def get_by_user(self, user: str)->list[MyPluginEntity]: + def get_by_user(self, user: str) -> list[MyPluginEntity]: session = self.get_session() my_plugins = session.query(MyPluginEntity) - if user: - my_plugins = my_plugins.filter( - MyPluginEntity.user_code == user - ) + if user: + my_plugins = my_plugins.filter(MyPluginEntity.user_code == user) result = my_plugins.all() session.close() return result - - def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]: + def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]: session = self.get_session() my_plugins = session.query(MyPluginEntity) all_count = my_plugins.count() if query.id is not None: my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) if query.name is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.name == query.name - ) + my_plugins = my_plugins.filter(MyPluginEntity.name == query.name) if query.tenant is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.tenant == query.tenant - ) + my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant) if query.type is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.type == query.type - ) + my_plugins = my_plugins.filter(MyPluginEntity.type == query.type) if query.user_code is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.user_code == query.user_code - ) + my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code) if query.user_name is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.user_name == query.user_name - ) + my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name) my_plugins = my_plugins.order_by(MyPluginEntity.id.desc()) - my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size) + my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size) result = my_plugins.all() session.close() total_pages = all_count // page_size if all_count % page_size != 0: total_pages += 1 - return result, total_pages, all_count - def count(self, query: MyPluginEntity): session = self.get_session() my_plugins = session.query(func.count(MyPluginEntity.id)) if query.id is not None: my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) if query.name is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.name == query.name - ) + my_plugins = my_plugins.filter(MyPluginEntity.name == query.name) if query.type is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.type == query.type - ) + my_plugins = my_plugins.filter(MyPluginEntity.type == query.type) if query.tenant is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.tenant == query.tenant - ) + my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant) if query.user_code is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.user_code == query.user_code - ) + my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code) if query.user_name is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.user_name == query.user_name - ) + my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name) count = my_plugins.scalar() session.close() return count - def delete(self, plugin_id: int): session = self.get_session() if plugin_id is None: @@ -148,9 +125,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]): query = MyPluginEntity(id=plugin_id) my_plugins = session.query(MyPluginEntity) if query.id is not None: - my_plugins = my_plugins.filter( - MyPluginEntity.id == query.id - ) + my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) my_plugins.delete() session.commit() session.close() diff --git a/pilot/base_modules/agent/db/plugin_hub_db.py b/pilot/base_modules/agent/db/plugin_hub_db.py index f5620fc79..8507bcc8e 100644 --- a/pilot/base_modules/agent/db/plugin_hub_db.py +++ b/pilot/base_modules/agent/db/plugin_hub_db.py @@ -10,8 +10,10 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session class PluginHubEntity(Base): - __tablename__ = 'plugin_hub' - id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") + __tablename__ = "plugin_hub" + id = Column( + Integer, primary_key=True, autoincrement=True, comment="autoincrement id" + ) name = Column(String(255), unique=True, nullable=False, comment="plugin name") description = Column(String(255), nullable=False, comment="plugin description") author = Column(String(255), nullable=True, comment="plugin author") @@ -25,8 +27,8 @@ class PluginHubEntity(Base): installed = Column(Integer, default=False, comment="plugin already installed count") __table_args__ = ( - UniqueConstraint('name', name="uk_name"), - Index('idx_q_type', 'type'), + UniqueConstraint("name", name="uk_name"), + Index("idx_q_type", "type"), ) @@ -38,7 +40,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]): def add(self, engity: PluginHubEntity): session = self.get_session() - timezone = pytz.timezone('Asia/Shanghai') + timezone = pytz.timezone("Asia/Shanghai") plugin_hub = PluginHubEntity( name=engity.name, author=engity.author, @@ -64,7 +66,9 @@ class PluginHubDao(BaseDao[PluginHubEntity]): finally: session.close() - def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]: + def list( + self, query: PluginHubEntity, page=1, page_size=20 + ) -> list[PluginHubEntity]: session = self.get_session() plugin_hubs = session.query(PluginHubEntity) all_count = plugin_hubs.count() @@ -72,17 +76,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]): if query.id is not None: plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id) if query.name is not None: - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.name == query.name - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name) if query.type is not None: - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.type == query.type - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type) if query.author is not None: - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.author == query.author - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author) if query.storage_channel is not None: plugin_hubs = plugin_hubs.filter( PluginHubEntity.storage_channel == query.storage_channel @@ -110,9 +108,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]): def get_by_name(self, name: str) -> PluginHubEntity: session = self.get_session() plugin_hubs = session.query(PluginHubEntity) - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.name == name - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name) result = plugin_hubs.first() session.close() return result @@ -123,17 +119,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]): if query.id is not None: plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id) if query.name is not None: - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.name == query.name - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name) if query.type is not None: - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.type == query.type - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type) if query.author is not None: - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.author == query.author - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author) if query.storage_channel is not None: plugin_hubs = plugin_hubs.filter( PluginHubEntity.storage_channel == query.storage_channel @@ -148,9 +138,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]): raise Exception("plugin_id is None") plugin_hubs = session.query(PluginHubEntity) if plugin_id is not None: - plugin_hubs = plugin_hubs.filter( - PluginHubEntity.id == plugin_id - ) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id) plugin_hubs.delete() session.commit() session.close() diff --git a/pilot/base_modules/agent/hub/agent_hub.py b/pilot/base_modules/agent/hub/agent_hub.py index b61158a40..608ab45d8 100644 --- a/pilot/base_modules/agent/hub/agent_hub.py +++ b/pilot/base_modules/agent/hub/agent_hub.py @@ -4,7 +4,7 @@ import os import glob import shutil from fastapi import UploadFile -from typing import Any +from typing import Any import tempfile from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao @@ -38,7 +38,9 @@ class AgentHub: download_param = json.loads(plugin_entity.download_param) branch_name = download_param.get("branch_name") authorization = download_param.get("authorization") - file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization) + file_name = self.__download_from_git( + plugin_entity.storage_url, branch_name, authorization + ) # add to my plugins and edit hub status plugin_entity.installed = plugin_entity.installed + 1 @@ -65,7 +67,9 @@ class AgentHub: logger.error("install pluguin exception!", e) raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}") else: - raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!") + raise ValueError( + f"Unsupport Storage Channel {plugin_entity.storage_channel}!" + ) else: raise ValueError(f"Can't Find Plugin {plugin_name}!") @@ -75,7 +79,9 @@ class AgentHub: plugin_entity.installed = plugin_entity.installed - 1 with self.hub_dao.get_session() as session: try: - my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name) + my_plugin_q = session.query(MyPluginEntity).filter( + MyPluginEntity.name == plugin_name + ) if user: my_plugin_q.filter(MyPluginEntity.user_code == user) my_plugin_q.delete() @@ -92,10 +98,10 @@ class AgentHub: have_installed = True break if not have_installed: - plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1] - files = glob.glob( - os.path.join(self.plugin_dir, f"{plugin_repo_name}*") + plugin_repo_name = ( + plugin_entity.storage_url.replace(".git", "").strip("/").split("/")[-1] ) + files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*")) for file in files: os.remove(file) @@ -109,9 +115,16 @@ class AgentHub: my_plugin_entity.version = hub_plugin.version return my_plugin_entity - def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None): + def refresh_hub_from_git( + self, + github_repo: str = None, + branch_name: str = None, + authorization: str = None, + ): logger.info("refresh_hub_by_git start!") - update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization) + update_from_git( + self.temp_hub_file_path, github_repo, branch_name, authorization + ) git_plugins = scan_plugins(self.temp_hub_file_path) try: for git_plugin in git_plugins: @@ -123,13 +136,13 @@ class AgentHub: plugin_hub_info.type = "" plugin_hub_info.storage_channel = PluginStorageType.Git.value plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO - plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT') - plugin_hub_info.email = getattr(git_plugin, '_email', '') + plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT") + plugin_hub_info.email = getattr(git_plugin, "_email", "") download_param = {} if branch_name: - download_param['branch_name'] = branch_name + download_param["branch_name"] = branch_name if authorization and len(authorization) > 0: - download_param['authorization'] = authorization + download_param["authorization"] = authorization plugin_hub_info.download_param = json.dumps(download_param) plugin_hub_info.installed = 0 @@ -140,15 +153,12 @@ class AgentHub: except Exception as e: raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}") - async def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User): - + async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User): # We can not move temp file in windows system when we open file in context of `with` file_path = os.path.join(self.plugin_dir, doc_file.filename) if os.path.exists(file_path): os.remove(file_path) - tmp_fd, tmp_path = tempfile.mkstemp( - dir=os.path.join(self.plugin_dir) - ) + tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir)) with os.fdopen(tmp_fd, "wb") as tmp: tmp.write(await doc_file.read()) shutil.move( @@ -158,14 +168,14 @@ class AgentHub: my_plugins = scan_plugins(self.plugin_dir, doc_file.filename) - if user is None or len(user) <=0: + if user is None or len(user) <= 0: user = Default_User for my_plugin in my_plugins: my_plugin_entiy = MyPluginEntity() my_plugin_entiy.name = my_plugin._name - my_plugin_entiy.version = my_plugin._version + my_plugin_entiy.version = my_plugin._version my_plugin_entiy.type = "Personal" my_plugin_entiy.user_code = user my_plugin_entiy.user_name = user @@ -183,4 +193,3 @@ class AgentHub: if not user: user = Default_User return self.my_lugin_dao.get_by_user(user) - diff --git a/pilot/base_modules/agent/model.py b/pilot/base_modules/agent/model.py index 8c02e4af4..eed2cb420 100644 --- a/pilot/base_modules/agent/model.py +++ b/pilot/base_modules/agent/model.py @@ -3,32 +3,35 @@ from dataclasses import dataclass from pydantic import BaseModel, Field from typing import TypeVar, Generic, Any -T = TypeVar('T') +T = TypeVar("T") + class PagenationFilter(BaseModel, Generic[T]): page_index: int = 1 - page_size: int = 20 + page_size: int = 20 filter: T = None + class PagenationResult(BaseModel, Generic[T]): page_index: int = 1 - page_size: int = 20 + page_size: int = 20 total_page: int = 0 total_row_count: int = 0 datas: List[T] = [] def to_dic(self): - data_dicts =[] + data_dicts = [] for item in self.datas: data_dicts.append(item.__dict__) return { - 'page_index': self.page_index, - 'page_size': self.page_size, - 'total_page': self.total_page, - 'total_row_count': self.total_row_count, - 'datas': data_dicts + "page_index": self.page_index, + "page_size": self.page_size, + "total_page": self.total_page, + "total_row_count": self.total_row_count, + "datas": data_dicts, } + @dataclass class PluginHubFilter(BaseModel): name: str @@ -53,9 +56,14 @@ class MyPluginFilter(BaseModel): class PluginHubParam(BaseModel): - channel: Optional[str] = Field("git", description="Plugin storage channel") - url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url") - branch: Optional[str] = Field("main", description="github download branch", nullable=True) - authorization: Optional[str] = Field(None, description="github download authorization", nullable=True) - - + channel: Optional[str] = Field("git", description="Plugin storage channel") + url: Optional[str] = Field( + "https://github.com/eosphoros-ai/DB-GPT-Plugins.git", + description="Plugin storage url", + ) + branch: Optional[str] = Field( + "main", description="github download branch", nullable=True + ) + authorization: Optional[str] = Field( + None, description="github download authorization", nullable=True + ) diff --git a/pilot/base_modules/agent/plugins_util.py b/pilot/base_modules/agent/plugins_util.py index b0a9c7885..facc1472d 100644 --- a/pilot/base_modules/agent/plugins_util.py +++ b/pilot/base_modules/agent/plugins_util.py @@ -117,7 +117,7 @@ def load_native_plugins(cfg: Config): t.start() -def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]: +def __scan_plugin_file(file_path, debug: bool = False) -> List[AutoGPTPluginTemplate]: logger.info(f"__scan_plugin_file:{file_path},{debug}") loaded_plugins = [] if moduleList := inspect_zip_for_modules(str(file_path), debug): @@ -133,14 +133,17 @@ def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTempl a_module = getattr(zipped_module, key) a_keys = dir(a_module) if ( - "_abc_impl" in a_keys - and a_module.__name__ != "AutoGPTPluginTemplate" - # and denylist_allowlist_check(a_module.__name__, cfg) + "_abc_impl" in a_keys + and a_module.__name__ != "AutoGPTPluginTemplate" + # and denylist_allowlist_check(a_module.__name__, cfg) ): loaded_plugins.append(a_module()) return loaded_plugins -def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]: + +def scan_plugins( + plugins_file_path: str, file_name: str = "", debug: bool = False +) -> List[AutoGPTPluginTemplate]: """Scan the plugins directory for plugins and loads them. Args: @@ -159,7 +162,7 @@ def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = Fals loaded_plugins = __scan_plugin_file(plugin_path) else: for plugin_path in plugins_path.glob("*.zip"): - loaded_plugins.extend(__scan_plugin_file(plugin_path)) + loaded_plugins.extend(__scan_plugin_file(plugin_path)) if loaded_plugins: logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") @@ -192,17 +195,23 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool: return ack.lower() == cfg.authorise_key -def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main", - authorization: str = None): +def update_from_git( + download_path: str, + github_repo: str = "", + branch_name: str = "main", + authorization: str = None, +): os.makedirs(download_path, exist_ok=True) if github_repo: if github_repo.index("github.com") <= 0: raise ValueError("Not a correct Github repository address!" + github_repo) github_repo = github_repo.replace(".git", "") url = github_repo + "/archive/refs/heads/" + branch_name + ".zip" - plugin_repo_name = github_repo.strip('/').split('/')[-1] + plugin_repo_name = github_repo.strip("/").split("/")[-1] else: - url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip" + url = ( + "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip" + ) plugin_repo_name = "DB-GPT-Plugins" try: session = requests.Session() @@ -216,14 +225,14 @@ def update_from_git(download_path: str, github_repo: str = "", branch_name: str if response.status_code == 200: plugins_path_path = Path(download_path) - files = glob.glob( - os.path.join(plugins_path_path, f"{plugin_repo_name}*") - ) + files = glob.glob(os.path.join(plugins_path_path, f"{plugin_repo_name}*")) for file in files: os.remove(file) now = datetime.datetime.now() time_str = now.strftime("%Y%m%d%H%M%S") - file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip" + file_name = ( + f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip" + ) print(file_name) with open(file_name, "wb") as f: f.write(response.content) diff --git a/pilot/base_modules/mange_base_api.py b/pilot/base_modules/mange_base_api.py index c0b5da273..57f2a27e7 100644 --- a/pilot/base_modules/mange_base_api.py +++ b/pilot/base_modules/mange_base_api.py @@ -1,7 +1,6 @@ class ModuleMangeApi: - def module_name(self): pass def register(self): - pass \ No newline at end of file + pass diff --git a/pilot/base_modules/meta_data/base_dao.py b/pilot/base_modules/meta_data/base_dao.py index 330bed592..693fe5699 100644 --- a/pilot/base_modules/meta_data/base_dao.py +++ b/pilot/base_modules/meta_data/base_dao.py @@ -1,11 +1,16 @@ from typing import TypeVar, Generic, List, Any from sqlalchemy.orm import sessionmaker -T = TypeVar('T') +T = TypeVar("T") + class BaseDao(Generic[T]): def __init__( - self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None, + self, + orm_base=None, + database: str = None, + db_engine: Any = None, + session: Any = None, ) -> None: """BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist""" self._orm_base = orm_base diff --git a/pilot/base_modules/meta_data/meta_data.py b/pilot/base_modules/meta_data/meta_data.py index 46a5f44b0..e5551da82 100644 --- a/pilot/base_modules/meta_data/meta_data.py +++ b/pilot/base_modules/meta_data/meta_data.py @@ -7,7 +7,7 @@ import fnmatch from datetime import datetime from typing import Optional, Type, TypeVar -from sqlalchemy import create_engine,DateTime, String, func, MetaData +from sqlalchemy import create_engine, DateTime, String, func, MetaData from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped @@ -32,16 +32,17 @@ db_name = "dbgpt" db_path = default_db_path + f"/{db_name}.db" connection = sqlite3.connect(db_path) -if CFG.LOCAL_DB_TYPE == 'mysql': - engine_temp = create_engine(f"mysql+pymysql://" - + quote(CFG.LOCAL_DB_USER) - + ":" - + quote(CFG.LOCAL_DB_PASSWORD) - + "@" - + CFG.LOCAL_DB_HOST - + ":" - + str(CFG.LOCAL_DB_PORT) - ) +if CFG.LOCAL_DB_TYPE == "mysql": + engine_temp = create_engine( + f"mysql+pymysql://" + + quote(CFG.LOCAL_DB_USER) + + ":" + + quote(CFG.LOCAL_DB_PASSWORD) + + "@" + + CFG.LOCAL_DB_HOST + + ":" + + str(CFG.LOCAL_DB_PORT) + ) # check and auto create mysqldatabase try: # try to connect @@ -53,20 +54,19 @@ if CFG.LOCAL_DB_TYPE == 'mysql': # if connect failed, create dbgpt database logger.error(f"{db_name} not connect success!") - engine = create_engine(f"mysql+pymysql://" - + quote(CFG.LOCAL_DB_USER) - + ":" - + quote(CFG.LOCAL_DB_PASSWORD) - + "@" - + CFG.LOCAL_DB_HOST - + ":" - + str(CFG.LOCAL_DB_PORT) - + f"/{db_name}" - ) + engine = create_engine( + f"mysql+pymysql://" + + quote(CFG.LOCAL_DB_USER) + + ":" + + quote(CFG.LOCAL_DB_PASSWORD) + + "@" + + CFG.LOCAL_DB_HOST + + ":" + + str(CFG.LOCAL_DB_PORT) + + f"/{db_name}" + ) else: - engine = create_engine(f'sqlite:///{db_path}') - - + engine = create_engine(f"sqlite:///{db_path}") Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -81,16 +81,16 @@ Base = declarative_base() alembic_ini_path = default_db_path + "/alembic.ini" alembic_cfg = AlembicConfig(alembic_ini_path) -alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url)) +alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url)) os.makedirs(default_db_path + "/alembic", exist_ok=True) os.makedirs(default_db_path + "/alembic/versions", exist_ok=True) -alembic_cfg.set_main_option('script_location', default_db_path + "/alembic") +alembic_cfg.set_main_option("script_location", default_db_path + "/alembic") # 将模型和会话传递给Alembic配置 -alembic_cfg.attributes['target_metadata'] = Base.metadata -alembic_cfg.attributes['session'] = session +alembic_cfg.attributes["target_metadata"] = Base.metadata +alembic_cfg.attributes["session"] = session # # 创建表 @@ -106,7 +106,7 @@ def ddl_init_and_upgrade(): # command.upgrade(alembic_cfg, 'head') # subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"]) with engine.connect() as connection: - alembic_cfg.attributes['connection'] = connection + alembic_cfg.attributes["connection"] = connection heads = command.heads(alembic_cfg) print("heads:" + str(heads)) diff --git a/pilot/base_modules/module_factory.py b/pilot/base_modules/module_factory.py index 139597f9c..8b1378917 100644 --- a/pilot/base_modules/module_factory.py +++ b/pilot/base_modules/module_factory.py @@ -1,2 +1 @@ - diff --git a/pilot/common/string_utils.py b/pilot/common/string_utils.py index 14bf5082e..170f0519a 100644 --- a/pilot/common/string_utils.py +++ b/pilot/common/string_utils.py @@ -1,27 +1,30 @@ import re + def is_all_chinese(text): ### Determine whether the string is pure Chinese - pattern = re.compile(r'^[一-龥]+$') + pattern = re.compile(r"^[一-龥]+$") match = re.match(pattern, text) return match is not None def is_number_chinese(text): ### Determine whether the string is numbers and Chinese - pattern = re.compile(r'^[\d一-龥]+$') + pattern = re.compile(r"^[\d一-龥]+$") match = re.match(pattern, text) return match is not None + def is_chinese_include_number(text): ### Determine whether the string is pure Chinese or Chinese containing numbers - pattern = re.compile(r'^[一-龥]+[\d一-龥]*$') + pattern = re.compile(r"^[一-龥]+[\d一-龥]*$") match = re.match(pattern, text) return match is not None + def is_scientific_notation(string): # 科学计数法的正则表达式 - pattern = r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$' + pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$" # 使用正则表达式匹配字符串 match = re.match(pattern, str(string)) # 判断是否匹配成功 @@ -30,28 +33,30 @@ def is_scientific_notation(string): else: return False + def extract_content(long_string, s1, s2, is_include: bool = False): # extract text - match_map ={} + match_map = {} start_index = long_string.find(s1) while start_index != -1: if is_include: end_index = long_string.find(s2, start_index + len(s1) + 1) - extracted_content = long_string[start_index:end_index + len(s2)] + extracted_content = long_string[start_index : end_index + len(s2)] else: end_index = long_string.find(s2, start_index + len(s1)) - extracted_content = long_string[start_index + len(s1):end_index] + extracted_content = long_string[start_index + len(s1) : end_index] if extracted_content: match_map[start_index] = extracted_content start_index = long_string.find(s1, start_index + 1) return match_map + def extract_content_open_ending(long_string, s1, s2, is_include: bool = False): # extract text open ending match_map = {} start_index = long_string.find(s1) while start_index != -1: - if long_string.find(s2, start_index) <=0: + if long_string.find(s2, start_index) <= 0: end_index = len(long_string) else: if is_include: @@ -59,19 +64,18 @@ def extract_content_open_ending(long_string, s1, s2, is_include: bool = False): else: end_index = long_string.find(s2, start_index + len(s1)) if is_include: - extracted_content = long_string[start_index:end_index + len(s2)] + extracted_content = long_string[start_index : end_index + len(s2)] else: - extracted_content = long_string[start_index + len(s1):end_index] + extracted_content = long_string[start_index + len(s1) : end_index] if extracted_content: match_map[start_index] = extracted_content - start_index= long_string.find(s1, start_index + 1) + start_index = long_string.find(s1, start_index + 1) return match_map - -if __name__=="__main__": +if __name__ == "__main__": s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123" s1 = "123" s2 = "456" - print(extract_content_open_ending(s, s1, s2, True)) \ No newline at end of file + print(extract_content_open_ending(s, s1, s2, True)) diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 0db89b70b..a25462c5e 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -55,20 +55,24 @@ class Config(metaclass=Singleton): self.tongyi_proxy_api_key = os.getenv("TONGYI_PROXY_API_KEY") if self.tongyi_proxy_api_key: os.environ["tongyi_proxyllm_proxy_api_key"] = self.tongyi_proxy_api_key - + # zhipu self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY") if self.zhipu_proxy_api_key: os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key - os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv("ZHIPU_MODEL_VERSION") + os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv( + "ZHIPU_MODEL_VERSION" + ) # wenxin self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY") - self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY") + self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY") self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION") if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret: os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key - os.environ["wenxin_proxyllm_proxy_api_secret"] = self.wenxin_proxy_api_secret + os.environ[ + "wenxin_proxyllm_proxy_api_secret" + ] = self.wenxin_proxy_api_secret os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version # xunfei spark @@ -90,8 +94,7 @@ class Config(metaclass=Singleton): os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version - - + self.proxy_server_url = os.getenv("PROXY_SERVER_URL") self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY") @@ -172,7 +175,6 @@ class Config(metaclass=Singleton): os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true" ) - self.LOCAL_DB_MANAGE = None ###dbgpt meta info database connection configuration @@ -190,7 +192,6 @@ class Config(metaclass=Singleton): self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb") - ### LLM Model Service Configuration self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5") ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm" diff --git a/pilot/connections/__init__.py b/pilot/connections/__init__.py index ce13a69f3..8cc9799db 100644 --- a/pilot/connections/__init__.py +++ b/pilot/connections/__init__.py @@ -1 +1 @@ -from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao \ No newline at end of file +from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao diff --git a/pilot/connections/manages/connect_config_db.py b/pilot/connections/manages/connect_config_db.py index 42307b243..7443d18ad 100644 --- a/pilot/connections/manages/connect_config_db.py +++ b/pilot/connections/manages/connect_config_db.py @@ -4,10 +4,13 @@ from typing import List from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text from sqlalchemy import UniqueConstraint + class ConnectConfigEntity(Base): - __tablename__ = 'connect_config' - id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") - db_type = Column(String(255), nullable=False, comment="db type") + __tablename__ = "connect_config" + id = Column( + Integer, primary_key=True, autoincrement=True, comment="autoincrement id" + ) + db_type = Column(String(255), nullable=False, comment="db type") db_name = Column(String(255), nullable=False, comment="db name") db_path = Column(String(255), nullable=True, comment="file db path") db_host = Column(String(255), nullable=True, comment="db connect host(not file db)") @@ -17,8 +20,8 @@ class ConnectConfigEntity(Base): comment = Column(Text, nullable=True, comment="db comment") __table_args__ = ( - UniqueConstraint('db_name', name="uk_db"), - Index('idx_q_db_type', 'db_type'), + UniqueConstraint("db_name", name="uk_db"), + Index("idx_q_db_type", "db_type"), ) @@ -43,9 +46,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]): raise Exception("db_name is None") db_connect = session.query(ConnectConfigEntity) - db_connect = db_connect.filter( - ConnectConfigEntity.db_name == db_name - ) + db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name) db_connect.delete() session.commit() session.close() @@ -53,10 +54,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]): def get_by_name(self, db_name: str) -> ConnectConfigEntity: session = self.get_session() db_connect = session.query(ConnectConfigEntity) - db_connect = db_connect.filter( - ConnectConfigEntity.db_name == db_name - ) + db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name) result = db_connect.first() session.close() return result - diff --git a/pilot/memory/__init__.py b/pilot/memory/__init__.py index 2e8c7af1b..77b30c893 100644 --- a/pilot/memory/__init__.py +++ b/pilot/memory/__init__.py @@ -1 +1 @@ -from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao \ No newline at end of file +from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py index 4d9291e0b..a8a09153c 100644 --- a/pilot/memory/chat_history/base.py +++ b/pilot/memory/chat_history/base.py @@ -5,15 +5,17 @@ from typing import List from enum import Enum from pilot.scene.message import OnceConversation + class MemoryStoreType(Enum): - File= 'file' - Memory = 'memory' - DB = 'db' - DuckDb = 'duckdb' + File = "file" + Memory = "memory" + DB = "db" + DuckDb = "duckdb" class BaseChatHistoryMemory(ABC): store_type: MemoryStoreType + def __init__(self): self.conversations: List[OnceConversation] = [] @@ -56,4 +58,4 @@ class BaseChatHistoryMemory(ABC): @staticmethod def conv_list(cls, user_name: str = None) -> None: - pass \ No newline at end of file + pass diff --git a/pilot/memory/chat_history/chat_hisotry_factory.py b/pilot/memory/chat_history/chat_hisotry_factory.py index 6c36053dd..64d30e971 100644 --- a/pilot/memory/chat_history/chat_hisotry_factory.py +++ b/pilot/memory/chat_history/chat_hisotry_factory.py @@ -4,9 +4,7 @@ from pilot.configs.config import Config CFG = Config() - class ChatHistory: - def __init__(self): self.memory_type = MemoryStoreType.DB.value self.mem_store_class_map = {} @@ -14,15 +12,16 @@ class ChatHistory: from .store_type.file_history import FileHistoryMemory from .store_type.meta_db_history import DbHistoryMemory from .store_type.mem_history import MemHistoryMemory + self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory - def get_store_instance(self, chat_session_id): - return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id) - + return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)( + chat_session_id + ) def get_store_cls(self): return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE) diff --git a/pilot/memory/chat_history/chat_history_db.py b/pilot/memory/chat_history/chat_history_db.py index f49e1983c..2b1a57c28 100644 --- a/pilot/memory/chat_history/chat_history_db.py +++ b/pilot/memory/chat_history/chat_history_db.py @@ -4,20 +4,28 @@ from typing import List from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text from sqlalchemy import UniqueConstraint + class ChatHistoryEntity(Base): - __tablename__ = 'chat_history' - id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") - conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id") + __tablename__ = "chat_history" + id = Column( + Integer, primary_key=True, autoincrement=True, comment="autoincrement id" + ) + conv_uid = Column( + String(255), + unique=False, + nullable=False, + comment="Conversation record unique id", + ) chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode") summary = Column(String(255), nullable=False, comment="Conversation record summary") user_name = Column(String(255), nullable=True, comment="interlocutor") messages = Column(Text, nullable=True, comment="Conversation details") __table_args__ = ( - UniqueConstraint('conv_uid', name="uk_conversation"), - Index('idx_q_user', 'user_name'), - Index('idx_q_mode', 'chat_mode'), - Index('idx_q_conv', 'summary'), + UniqueConstraint("conv_uid", name="uk_conversation"), + Index("idx_q_user", "user_name"), + Index("idx_q_mode", "chat_mode"), + Index("idx_q_conv", "summary"), ) @@ -31,9 +39,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]): session = self.get_session() chat_history = session.query(ChatHistoryEntity) if user_name: - chat_history = chat_history.filter( - ChatHistoryEntity.user_name == user_name - ) + chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name) chat_history = chat_history.order_by(ChatHistoryEntity.id.desc()) @@ -50,13 +56,11 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]): finally: session.close() - def update_message_by_uid(self, message: str, conv_uid:str): + def update_message_by_uid(self, message: str, conv_uid: str): session = self.get_session() try: chat_history = session.query(ChatHistoryEntity) - chat_history = chat_history.filter( - ChatHistoryEntity.conv_uid == conv_uid - ) + chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) updated = chat_history.update({ChatHistoryEntity.messages: message}) session.commit() return updated.id @@ -69,9 +73,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]): raise Exception("conv_uid is None") chat_history = session.query(ChatHistoryEntity) - chat_history = chat_history.filter( - ChatHistoryEntity.conv_uid == conv_uid - ) + chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) chat_history.delete() session.commit() session.close() @@ -79,10 +81,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]): def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity: session = self.get_session() chat_history = session.query(ChatHistoryEntity) - chat_history = chat_history.filter( - ChatHistoryEntity.conv_uid == conv_uid - ) + chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) result = chat_history.first() session.close() return result - diff --git a/pilot/memory/chat_history/store_type/duckdb_history.py b/pilot/memory/chat_history/store_type/duckdb_history.py index 97aae159b..28c92a142 100644 --- a/pilot/memory/chat_history/store_type/duckdb_history.py +++ b/pilot/memory/chat_history/store_type/duckdb_history.py @@ -148,7 +148,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return json.loads(context[0]) return None - @staticmethod def conv_list(cls, user_name: str = None) -> None: if os.path.isfile(duckdb_path): diff --git a/pilot/memory/chat_history/store_type/file_history.py b/pilot/memory/chat_history/store_type/file_history.py index fa1143309..a4623db36 100644 --- a/pilot/memory/chat_history/store_type/file_history.py +++ b/pilot/memory/chat_history/store_type/file_history.py @@ -17,7 +17,7 @@ CFG = Config() class FileHistoryMemory(BaseChatHistoryMemory): - store_type: str = MemoryStoreType.File.value + store_type: str = MemoryStoreType.File.value def __init__(self, chat_session_id: str): now = datetime.datetime.now() @@ -49,5 +49,3 @@ class FileHistoryMemory(BaseChatHistoryMemory): def clear(self) -> None: self.file_path.write_text(json.dumps([])) - - diff --git a/pilot/memory/chat_history/store_type/mem_history.py b/pilot/memory/chat_history/store_type/mem_history.py index 5c3ddc217..81f1438de 100644 --- a/pilot/memory/chat_history/store_type/mem_history.py +++ b/pilot/memory/chat_history/store_type/mem_history.py @@ -10,7 +10,7 @@ CFG = Config() class MemHistoryMemory(BaseChatHistoryMemory): - store_type: str = MemoryStoreType.Memory.value + store_type: str = MemoryStoreType.Memory.value histroies_map = FixedSizeDict(100) diff --git a/pilot/memory/chat_history/store_type/meta_db_history.py b/pilot/memory/chat_history/store_type/meta_db_history.py index c1fc0ec5d..137d0b161 100644 --- a/pilot/memory/chat_history/store_type/meta_db_history.py +++ b/pilot/memory/chat_history/store_type/meta_db_history.py @@ -12,18 +12,22 @@ from pilot.scene.message import ( from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao from pilot.memory.chat_history.base import MemoryStoreType + CFG = Config() logger = logging.getLogger("db_chat_history") + class DbHistoryMemory(BaseChatHistoryMemory): - store_type: str = MemoryStoreType.DB.value + store_type: str = MemoryStoreType.DB.value + def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id self.chat_history_dao = ChatHistoryDao() def messages(self) -> List[OnceConversation]: - - chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id) + chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid( + self.chat_seesion_id + ) if chat_history: context = chat_history.messages if context: @@ -31,7 +35,6 @@ class DbHistoryMemory(BaseChatHistoryMemory): return conversations return [] - def create(self, chat_mode, summary: str, user_name: str) -> None: try: chat_history: ChatHistoryEntity = ChatHistoryEntity() @@ -43,10 +46,11 @@ class DbHistoryMemory(BaseChatHistoryMemory): except Exception as e: logger.error("init create conversation log error!" + str(e)) - def append(self, once_message: OnceConversation) -> None: logger.info("db history append:{}", once_message) - chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id) + chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid( + self.chat_seesion_id + ) conversations: List[OnceConversation] = [] if chat_history: context = chat_history.messages @@ -59,7 +63,7 @@ class DbHistoryMemory(BaseChatHistoryMemory): chat_history.conv_uid = self.chat_seesion_id chat_history.chat_mode = once_message.chat_mode chat_history.user_name = "default" - chat_history.summary = once_message.get_user_conv().content + chat_history.summary = once_message.get_user_conv().content conversations.append(_conversation_to_dic(once_message)) chat_history.messages = json.dumps(conversations, ensure_ascii=False) @@ -67,34 +71,31 @@ class DbHistoryMemory(BaseChatHistoryMemory): self.chat_history_dao.update(chat_history) def update(self, messages: List[OnceConversation]) -> None: - self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id) - + self.chat_history_dao.update_message_by_uid( + json.dumps(messages, ensure_ascii=False), self.chat_seesion_id + ) def delete(self) -> bool: self.chat_history_dao.delete(self.chat_seesion_id) - def conv_info(self, conv_uid: str = None) -> None: logger.info("conv_info:{}", conv_uid) - chat_history = self.chat_history_dao.get_by_uid(conv_uid) + chat_history = self.chat_history_dao.get_by_uid(conv_uid) return chat_history.__dict__ - def get_messages(self) -> List[OnceConversation]: logger.info("get_messages:{}", self.chat_seesion_id) - chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id) + chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id) if chat_history: context = chat_history.messages return json.loads(context) return [] - @staticmethod def conv_list(cls, user_name: str = None) -> None: - chat_history_dao = ChatHistoryDao() history_list = chat_history_dao.list_last_20() result = [] for history in history_list: result.append(history.__dict__) - return result \ No newline at end of file + return result diff --git a/pilot/meta_data/alembic/env.py b/pilot/meta_data/alembic/env.py index e40929dad..507a27ab4 100644 --- a/pilot/meta_data/alembic/env.py +++ b/pilot/meta_data/alembic/env.py @@ -66,12 +66,14 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - if engine.dialect.name == 'sqlite': - context.configure(connection=engine.connect(), target_metadata=target_metadata, render_as_batch=True) - else: + if engine.dialect.name == "sqlite": context.configure( - connection=connection, target_metadata=target_metadata + connection=engine.connect(), + target_metadata=target_metadata, + render_as_batch=True, ) + else: + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index 3bb9d7b93..7751ddb5a 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -145,12 +145,12 @@ def initialize_controller( controller.backend = LocalModelController() if app: - app.include_router(router, prefix="/api", tags=['Model']) + app.include_router(router, prefix="/api", tags=["Model"]) else: import uvicorn app = FastAPI() - app.include_router(router, prefix="/api", tags=['Model']) + app.include_router(router, prefix="/api", tags=["Model"]) uvicorn.run(app, host=host, port=port, log_level="info") diff --git a/pilot/model/proxy/llms/baichuan.py b/pilot/model/proxy/llms/baichuan.py index 6dd5cacad..ae4f72283 100644 --- a/pilot/model/proxy/llms/baichuan.py +++ b/pilot/model/proxy/llms/baichuan.py @@ -9,18 +9,21 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B" + def _calculate_md5(text: str) -> str: - """Calculate md5 """ + """Calculate md5""" md5 = hashlib.md5() md5.update(text.encode("utf-8")) encrypted = md5.hexdigest() return encrypted + def _sign(data: dict, secret_key: str, timestamp: str): data_str = json.dumps(data) signature = _calculate_md5(secret_key + data_str + timestamp) return signature + def baichuan_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=4096 ): @@ -28,12 +31,11 @@ def baichuan_generate_stream( url = "https://api.baichuan-ai.com/v1/stream/chat" model_name = model_params.proxyllm_backend or BAICHUAN_DEFAULT_MODEL - proxy_api_key = model_params.proxy_api_key - proxy_api_secret = model_params.proxy_api_secret - + proxy_api_key = model_params.proxy_api_key + proxy_api_secret = model_params.proxy_api_secret history = [] - messages: List[ModelMessage] = params["messages"] + messages: List[ModelMessage] = params["messages"] # Add history conversation for message in messages: if message.role == ModelMessageRoleType.HUMAN: @@ -47,23 +49,23 @@ def baichuan_generate_stream( payload = { "model": model_name, - "messages": history, + "messages": history, "parameters": { "temperature": params.get("temperature"), - "top_k": params.get("top_k", 10) - } + "top_k": params.get("top_k", 10), + }, } timestamp = int(time.time()) _signature = _sign(payload, proxy_api_secret, str(timestamp)) - + headers = { "Content-Type": "application/json", "Authorization": "Bearer " + proxy_api_key, "X-BC-Request-Id": params.get("request_id") or "dbgpt", "X-BC-Timestamp": str(timestamp), "X-BC-Signature": _signature, - "X-BC-Sign-Algo": "MD5", + "X-BC-Sign-Algo": "MD5", } res = requests.post(url=url, json=payload, headers=headers, stream=True) diff --git a/pilot/model/proxy/llms/spark.py b/pilot/model/proxy/llms/spark.py index 2a6a1579a..72a9ccd2f 100644 --- a/pilot/model/proxy/llms/spark.py +++ b/pilot/model/proxy/llms/spark.py @@ -3,7 +3,7 @@ import json import base64 import hmac import hashlib -import websockets +import websockets from datetime import datetime from typing import List from time import mktime @@ -15,25 +15,25 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel SPARK_DEFAULT_API_VERSION = "v2" + def spark_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): model_params = model.get_params() proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION - proxy_api_key = model_params.proxy_api_key - proxy_api_secret = model_params.proxy_api_secret - proxy_app_id = model_params.proxy_app_id + proxy_api_key = model_params.proxy_api_key + proxy_api_secret = model_params.proxy_api_secret + proxy_app_id = model_params.proxy_app_id if proxy_api_version == SPARK_DEFAULT_API_VERSION: url = "ws://spark-api.xf-yun.com/v2.1/chat" - domain = "generalv2" + domain = "generalv2" else: domain = "general" url = "ws://spark-api.xf-yun.com/v1.1/chat" - - messages: List[ModelMessage] = params["messages"] - + messages: List[ModelMessage] = params["messages"] + history = [] # Add history conversation for message in messages: @@ -45,7 +45,7 @@ def spark_generate_stream( history.append({"role": "assistant", "content": message.content}) else: pass - + spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url) request_url = spark_api.gen_url() @@ -55,78 +55,73 @@ def spark_generate_stream( if m["role"] == "user": last_user_input = m break - + data = { - "header": { - "app_id": proxy_app_id, - "uid": params.get("request_id", 1) - }, + "header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)}, "parameter": { "chat": { "domain": domain, "random_threshold": 0.5, "max_tokens": context_len, "auditing": "default", - "temperature": params.get("temperature") + "temperature": params.get("temperature"), } }, - "payload": { - "message": { - "text": last_user_input.get("content") - } - } + "payload": {"message": {"text": last_user_input.get("content")}}, } async_call(request_url, data) + async def async_call(request_url, data): async with websockets.connect(request_url) as ws: await ws.send(json.dumps(data, ensure_ascii=False)) finish = False while not finish: - chunk = ws.recv() + chunk = ws.recv() response = json.loads(chunk) if response.get("header", {}).get("status") == 2: finish = True if text := response.get("payload", {}).get("choices", {}).get("text"): - yield text[0]["content"] - + yield text[0]["content"] + + class SparkAPI: - - def __init__(self, appid: str, api_key: str, api_secret: str, spark_url: str) -> None: + def __init__( + self, appid: str, api_key: str, api_secret: str, spark_url: str + ) -> None: self.appid = appid self.api_key = api_key self.api_secret = api_secret self.host = urlparse(spark_url).netloc self.path = urlparse(spark_url).path - + self.spark_url = spark_url - def gen_url(self): - now = datetime.now() - date = format_date_time(mktime(now.timetuple())) + date = format_date_time(mktime(now.timetuple())) _signature = "host: " + self.host + "\n" _signature += "data: " + date + "\n" _signature += "GET " + self.path + " HTTP/1.1" - _signature_sha = hmac.new(self.api_secret.encode("utf-8"), _signature.encode("utf-8"), - digestmod=hashlib.sha256).digest() + _signature_sha = hmac.new( + self.api_secret.encode("utf-8"), + _signature.encode("utf-8"), + digestmod=hashlib.sha256, + ).digest() - _signature_sha_base64 = base64.b64encode(_signature_sha).decode(encoding="utf-8") + _signature_sha_base64 = base64.b64encode(_signature_sha).decode( + encoding="utf-8" + ) _authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'" - authorization = base64.b64encode(_authorization.encode('utf-8')).decode(encoding='utf-8') - - v = { - "authorization": authorization, - "date": date, - "host": self.host - } + authorization = base64.b64encode(_authorization.encode("utf-8")).decode( + encoding="utf-8" + ) + + v = {"authorization": authorization, "date": date, "host": self.host} url = self.spark_url + "?" + urlencode(v) - return url - - \ No newline at end of file + return url diff --git a/pilot/model/proxy/llms/tongyi.py b/pilot/model/proxy/llms/tongyi.py index f1a95928f..fb826e49c 100644 --- a/pilot/model/proxy/llms/tongyi.py +++ b/pilot/model/proxy/llms/tongyi.py @@ -8,10 +8,11 @@ logger = logging.getLogger(__name__) def tongyi_generate_stream( - model: ProxyModel, tokenizer, params, device, context_len=2048 + model: ProxyModel, tokenizer, params, device, context_len=2048 ): import dashscope from dashscope import Generation + model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") @@ -62,14 +63,14 @@ def tongyi_generate_stream( messages=history, top_p=params.get("top_p", 0.8), stream=True, - result_format='message' + result_format="message", ) for r in res: if r: - if r['status_code'] == 200: + if r["status_code"] == 200: content = r["output"]["choices"][0]["message"].get("content") yield content else: - content = r['code'] + ":" + r["message"] + content = r["code"] + ":" + r["message"] yield content diff --git a/pilot/model/proxy/llms/wenxin.py b/pilot/model/proxy/llms/wenxin.py index 262528939..acc82907c 100644 --- a/pilot/model/proxy/llms/wenxin.py +++ b/pilot/model/proxy/llms/wenxin.py @@ -6,20 +6,26 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from cachetools import cached, TTLCache + @cached(TTLCache(1, 1800)) def _build_access_token(api_key: str, secret_key: str) -> str: """ - Generate Access token according AK, SK + Generate Access token according AK, SK """ - + url = "https://aip.baidubce.com/oauth/2.0/token" - params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} + params = { + "grant_type": "client_credentials", + "client_id": api_key, + "client_secret": secret_key, + } res = requests.get(url=url, params=params) - + if res.status_code == 200: return res.json().get("access_token") + def wenxin_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): @@ -28,23 +34,20 @@ def wenxin_generate_stream( "ERNIE-Bot-turbo": "eb-instant", } - model_params = model.get_params() - model_name = model_params.proxyllm_backend + model_params = model.get_params() + model_name = model_params.proxyllm_backend model_version = MODEL_VERSION.get(model_name) if not model_version: yield f"Unsupport model version {model_name}" - - proxy_api_key = model_params.proxy_api_key - proxy_api_secret = model_params.proxy_api_secret - access_token = _build_access_token(proxy_api_key, proxy_api_secret) - - headers = { - "Content-Type": "application/json", - "Accept": "application/json" - } - proxy_server_url = f'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}' - + proxy_api_key = model_params.proxy_api_key + proxy_api_secret = model_params.proxy_api_secret + access_token = _build_access_token(proxy_api_key, proxy_api_secret) + + headers = {"Content-Type": "application/json", "Accept": "application/json"} + + proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}" + if not access_token: yield "Failed to get access token. please set the correct api_key and secret key." @@ -86,7 +89,7 @@ def wenxin_generate_stream( "messages": history, "system": system, "temperature": params.get("temperature"), - "stream": True + "stream": True, } text = "" @@ -106,6 +109,3 @@ def wenxin_generate_stream( content = obj["result"] text += content yield text - - - \ No newline at end of file diff --git a/pilot/model/proxy/llms/zhipu.py b/pilot/model/proxy/llms/zhipu.py index 3a2a1edbb..89e7dd9a0 100644 --- a/pilot/model/proxy/llms/zhipu.py +++ b/pilot/model/proxy/llms/zhipu.py @@ -7,6 +7,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType CHATGLM_DEFAULT_MODEL = "chatglm_pro" + def zhipu_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): @@ -16,9 +17,10 @@ def zhipu_generate_stream( # TODO proxy model use unified config? proxy_api_key = model_params.proxy_api_key - proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend + proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend import zhipuai + zhipuai.api_key = proxy_api_key history = [] @@ -63,4 +65,4 @@ def zhipu_generate_stream( ) for r in res.events(): if r.event == "add": - yield r.data \ No newline at end of file + yield r.data diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 100b40320..ea569b3e4 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -297,7 +297,6 @@ async def params_load( @router.post("/v1/chat/dialogue/delete") async def dialogue_delete(con_uid: str): - history_fac = ChatHistory() history_mem = history_fac.get_store_instance(con_uid) history_mem.delete() diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index be0231525..d1e6ad8ac 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -54,7 +54,7 @@ class BaseChat(ABC): ) chat_history_fac = ChatHistory() ### can configurable storage methods - self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"]) + self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"]) self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation( diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py index bb9872308..4734c8106 100644 --- a/pilot/scene/chat_agent/chat.py +++ b/pilot/scene/chat_agent/chat.py @@ -20,10 +20,11 @@ logger = logging.getLogger("chat_agent") class ChatAgent(BaseChat): chat_scene: str = ChatScene.ChatAgent.value() chat_retention_rounds = 0 + def __init__(self, chat_param: Dict): - if not chat_param['select_param']: + if not chat_param["select_param"]: raise ValueError("Please select a Plugin!") - self.select_plugins = chat_param['select_param'].split(",") + self.select_plugins = chat_param["select_param"].split(",") chat_param["chat_mode"] = ChatScene.ChatAgent super().__init__(chat_param=chat_param) @@ -31,8 +32,12 @@ class ChatAgent(BaseChat): self.plugins_prompt_generator.command_registry = CFG.command_registry # load select plugin - agent_module = CFG.SYSTEM_APP.get_component(ComponentType.AGENT_HUB, ModuleAgent) - self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins) + agent_module = CFG.SYSTEM_APP.get_component( + ComponentType.AGENT_HUB, ModuleAgent + ) + self.plugins_prompt_generator = agent_module.load_select_plugin( + self.plugins_prompt_generator, self.select_plugins + ) self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator) @@ -53,4 +58,3 @@ class ChatAgent(BaseChat): def __list_to_prompt_str(self, list: List) -> str: return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) - diff --git a/pilot/scene/chat_agent/prompt.py b/pilot/scene/chat_agent/prompt.py index 1c5229f89..a42fd6363 100644 --- a/pilot/scene/chat_agent/prompt.py +++ b/pilot/scene/chat_agent/prompt.py @@ -54,7 +54,7 @@ _DEFAULT_TEMPLATE = ( ) -_PROMPT_SCENE_DEFINE=( +_PROMPT_SCENE_DEFINE = ( _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH ) @@ -76,7 +76,7 @@ prompt = PromptTemplate( output_parser=PluginChatOutputParser( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT ), - temperature = 1 + temperature=1 # example_selector=plugin_example, ) diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index 4611aa14d..9599c1402 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -36,7 +36,7 @@ class ChatExcel(BaseChat): KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param ) ) - self.api_call = ApiCall(display_registry = CFG.command_disply) + self.api_call = ApiCall(display_registry=CFG.command_disply) super().__init__(chat_param=chat_param) def _generate_numbered_list(self) -> str: @@ -79,4 +79,4 @@ class ChatExcel(BaseChat): def stream_plugin_call(self, text): text = text.replace("\n", " ") print(f"stream_plugin_call:{text}") - return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex) \ No newline at end of file + return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex) diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py index fad30bb56..ee82b51a0 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py @@ -44,7 +44,7 @@ _DEFAULT_TEMPLATE = ( _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH ) -PROMPT_SCENE_DEFINE =( +PROMPT_SCENE_DEFINE = ( _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH ) diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py index 0661737f0..c9e4aa785 100644 --- a/pilot/scene/chat_data/chat_excel/excel_reader.py +++ b/pilot/scene/chat_data/chat_excel/excel_reader.py @@ -9,8 +9,21 @@ import pandas as pd import chardet import pandas as pd import numpy as np -from pyparsing import CaselessKeyword, Word, alphas, alphanums, delimitedList, Forward, Group, Optional,\ - Literal, infixNotation, opAssoc, unicodeString,Regex +from pyparsing import ( + CaselessKeyword, + Word, + alphas, + alphanums, + delimitedList, + Forward, + Group, + Optional, + Literal, + infixNotation, + opAssoc, + unicodeString, + Regex, +) from pilot.common.pd_utils import csv_colunm_foramt from pilot.common.string_utils import is_chinese_include_number @@ -21,14 +34,15 @@ def excel_colunm_format(old_name: str) -> str: new_column = new_column.replace(" ", "_") return new_column + def detect_encoding(file_path): # 读取文件的二进制数据 - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: data = f.read() # 使用 chardet 来检测文件编码 result = chardet.detect(data) - encoding = result['encoding'] - confidence = result['confidence'] + encoding = result["encoding"] + confidence = result["confidence"] return encoding, confidence @@ -39,10 +53,11 @@ def add_quotes_ex(sql: str, column_names): sql = sql.replace(column_name, f'"{column_name}"') return sql + def parse_sql(sql): # 定义关键字和标识符 select_stmt = Forward() - column = Regex(r'[\w一-龥]*') + column = Regex(r"[\w一-龥]*") table = Word(alphanums) join_expr = Forward() where_expr = Forward() @@ -62,21 +77,24 @@ def parse_sql(sql): not_in_keyword = CaselessKeyword("NOT IN") # 定义语法规则 - select_stmt <<= (select_keyword + delimitedList(column) + - from_keyword + delimitedList(table) + - Optional(join_expr) + - Optional(where_keyword + where_expr) + - Optional(group_by_keyword + group_by_expr) + - Optional(order_by_keyword + order_by_expr)) + select_stmt <<= ( + select_keyword + + delimitedList(column) + + from_keyword + + delimitedList(table) + + Optional(join_expr) + + Optional(where_keyword + where_expr) + + Optional(group_by_keyword + group_by_expr) + + Optional(order_by_keyword + order_by_expr) + ) join_expr <<= join_keyword + table + on_keyword + column + Literal("=") + column - where_expr <<= column + Literal("=") + Word(alphanums) + \ - Optional(and_keyword + where_expr) | \ - column + Literal(">") + Word(alphanums) + \ - Optional(and_keyword + where_expr) | \ - column + Literal("<") + Word(alphanums) + \ - Optional(and_keyword + where_expr) + where_expr <<= ( + column + Literal("=") + Word(alphanums) + Optional(and_keyword + where_expr) + | column + Literal(">") + Word(alphanums) + Optional(and_keyword + where_expr) + | column + Literal("<") + Word(alphanums) + Optional(and_keyword + where_expr) + ) group_by_expr <<= delimitedList(column) @@ -88,7 +106,6 @@ def parse_sql(sql): return parsed_result.asList() - def add_quotes(sql, column_names=[]): sql = sql.replace("`", "") sql = sql.replace("'", "") @@ -108,6 +125,7 @@ def deep_quotes(token, column_names=[]): new_value = token.value.replace("`", "").replace("'", "") token.value = f'"{new_value}"' + def get_select_clause(sql): parsed = sqlparse.parse(sql)[0] # 解析 SQL 语句,获取第一个语句块 @@ -123,6 +141,7 @@ def get_select_clause(sql): select_tokens.append(token) return "".join(str(token) for token in select_tokens) + def parse_select_fields(sql): parsed = sqlparse.parse(sql)[0] # 解析 SQL 语句,获取第一个语句块 fields = [] @@ -139,12 +158,14 @@ def parse_select_fields(sql): return fields + def add_quotes_to_chinese_columns(sql, column_names=[]): parsed = sqlparse.parse(sql) for stmt in parsed: process_statement(stmt, column_names) return str(parsed[0]) + def process_statement(statement, column_names=[]): if isinstance(statement, sqlparse.sql.IdentifierList): for identifier in statement.get_identifiers(): @@ -155,22 +176,23 @@ def process_statement(statement, column_names=[]): for item in statement.tokens: process_statement(item) + def process_identifier(identifier, column_names=[]): # if identifier.has_alias(): # alias = identifier.get_alias() # identifier.tokens[-1].value = '[' + alias + ']' - if hasattr(identifier, 'tokens') and identifier.value in column_names: - if is_chinese(identifier.value): + if hasattr(identifier, "tokens") and identifier.value in column_names: + if is_chinese(identifier.value): new_value = get_new_value(identifier.value) identifier.value = new_value identifier.normalized = new_value identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)] else: - if hasattr(identifier, 'tokens'): + if hasattr(identifier, "tokens"): for token in identifier.tokens: if isinstance(token, sqlparse.sql.Function): process_function(token) - elif token.ttype in sqlparse.tokens.Name : + elif token.ttype in sqlparse.tokens.Name: new_value = get_new_value(token.value) token.value = new_value token.normalized = new_value @@ -179,9 +201,12 @@ def process_identifier(identifier, column_names=[]): token.value = new_value token.normalized = new_value token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)] + + def get_new_value(value): return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """ + def process_function(function): function_params = list(function.get_parameters()) # for param in function_params: @@ -191,15 +216,18 @@ def process_function(function): if isinstance(param, sqlparse.sql.Identifier): # 判断是否需要替换字段值 # if is_chinese(param.value): - # 替换字段值 + # 替换字段值 new_value = get_new_value(param.value) # new_parameter = sqlparse.sql.Identifier(f'[{param.value}]') - function_params[i].tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)] + function_params[i].tokens = [ + sqlparse.sql.Token(sqlparse.tokens.Name, new_value) + ] print(str(function)) + def is_chinese(text): for char in text: - if '一' <= char <= '鿿': + if "一" <= char <= "鿿": return True return False @@ -240,7 +268,7 @@ class ExcelReader: df_tmp = pd.read_csv(file_path, encoding=encoding) self.df = pd.read_csv( file_path, - encoding = encoding, + encoding=encoding, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}, ) else: @@ -280,7 +308,6 @@ class ExcelReader: logging.error("excel sql run error!", e) raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}") - def get_df_by_sql_ex(self, sql): colunms, values = self.run(sql) return pd.DataFrame(values, columns=colunms) diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index 2e3b759ee..e4c5175a6 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -16,7 +16,7 @@ class ChatWithPlugin(BaseChat): select_plugin: str = None def __init__(self, chat_param: Dict): - self.plugin_selector = chat_param["select_param"] + self.plugin_selector = chat_param["select_param"] chat_param["chat_mode"] = ChatScene.ChatExecution super().__init__(chat_param=chat_param) self.plugins_prompt_generator = PluginPromptGenerator() diff --git a/pilot/server/base.py b/pilot/server/base.py index 3cc9b16df..42b1f6a33 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -15,7 +15,6 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi sys.path.append(ROOT_PATH) - def signal_handler(sig, frame): print("in order to avoid chroma db atexit problem") os._exit(0) @@ -32,7 +31,6 @@ def async_db_summary(system_app: SystemApp): def server_init(args, system_app: SystemApp): from pilot.base_modules.agent.commands.command_mange import CommandRegistry - # logger.info(f"args: {args}") # init config @@ -44,8 +42,6 @@ def server_init(args, system_app: SystemApp): # load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) - - # Loader plugins and commands command_categories = [ "pilot.base_modules.agent.commands.built_in.audio_text", @@ -126,4 +122,3 @@ class WebWerverParameters(BaseParameters): }, ) light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"}) - diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py index 06f2c4a16..d700de94c 100644 --- a/pilot/server/component_configs.py +++ b/pilot/server/component_configs.py @@ -29,6 +29,7 @@ def initialize_components( system_app.register_instance(controller) from pilot.base_modules.agent.controller import module_agent + system_app.register_instance(module_agent) _initialize_embedding_model( diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 71ef9f8d8..635ae99cf 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -31,7 +31,9 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1 from pilot.openapi.base import validation_exception_handler from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1 -from pilot.base_modules.agent.commands.disply_type.show_chart_gen import static_message_img_path +from pilot.base_modules.agent.commands.disply_type.show_chart_gen import ( + static_message_img_path, +) from pilot.model.cluster import initialize_worker_manager_in_client from pilot.utils.utils import ( setup_logging, @@ -56,6 +58,8 @@ def swagger_monkey_patch(*args, **kwargs): swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js", swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css" ) + + app = FastAPI() applications.get_swagger_ui_html = swagger_monkey_patch @@ -73,14 +77,14 @@ app.add_middleware( ) -app.include_router(api_v1, prefix="/api", tags=["Chat"]) -app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"]) +app.include_router(api_v1, prefix="/api", tags=["Chat"]) +app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"]) app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"]) app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"]) -app.include_router(knowledge_router, tags=["Knowledge"]) -app.include_router(prompt_router, tags=["Prompt"]) +app.include_router(knowledge_router, tags=["Knowledge"]) +app.include_router(prompt_router, tags=["Prompt"]) def mount_static_files(app): @@ -98,6 +102,7 @@ def mount_static_files(app): app.add_exception_handler(RequestValidationError, validation_exception_handler) + def _get_webserver_params(args: List[str] = None): from pilot.utils.parameter_utils import EnvArgumentParser @@ -106,6 +111,7 @@ def _get_webserver_params(args: List[str] = None): ) return WebWerverParameters(**vars(parser.parse_args(args=args))) + def initialize_app(param: WebWerverParameters = None, args: List[str] = None): """Initialize app If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.