From 8897d6e8fd14328d143c9f9c2e2c9570af084593 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 15 Mar 2024 15:42:46 +0800 Subject: [PATCH] chore: Add pylint for storage (#1298) --- .mypy.ini | 19 +- Makefile | 8 +- assets/schema/dbgpt.sql | 2 +- .../api_v1/editor/_chat_history/__init__.py | 4 + .../api_v1/editor/_chat_history}/base.py | 23 +-- .../_chat_history}/chat_hisotry_factory.py | 40 ++-- .../editor/_chat_history}/meta_db_history.py | 75 ++++---- .../openapi/api_v1/editor/api_editor_v1.py | 8 +- dbgpt/app/scene/operators/app_operator.py | 3 +- dbgpt/core/awel/trigger/iterator_trigger.py | 18 +- dbgpt/model/cluster/embedding/loader.py | 4 +- .../cluster/embedding/remote_embedding.py | 3 +- .../model/cluster/worker/embedding_worker.py | 10 +- dbgpt/rag/chunk_manager.py | 3 +- dbgpt/serve/flow/service/service.py | 1 + dbgpt/storage/__init__.py | 3 + dbgpt/storage/cache/__init__.py | 7 +- dbgpt/storage/cache/embedding_cache.py | 1 + dbgpt/storage/cache/llm_cache.py | 83 ++++++-- dbgpt/storage/cache/manager.py | 54 ++++-- .../cache/{operator.py => operators.py} | 70 ++++--- dbgpt/storage/cache/protocal/__init__.py | 0 dbgpt/storage/cache/protocol/__init__.py | 1 + dbgpt/storage/cache/storage/__init__.py | 1 + dbgpt/storage/cache/storage/base.py | 26 ++- dbgpt/storage/cache/storage/disk/__init__.py | 1 + .../cache/storage/disk/disk_storage.py | 22 ++- .../cache/storage/tests/test_storage.py | 2 - dbgpt/storage/chat_history/__init__.py | 18 ++ dbgpt/storage/chat_history/chat_history_db.py | 23 ++- dbgpt/storage/chat_history/storage_adapter.py | 55 ++++-- .../chat_history/store_type/__init__.py | 0 .../chat_history/store_type/duckdb_history.py | 182 ------------------ .../chat_history/store_type/file_history.py | 50 ----- .../chat_history/store_type/mem_history.py | 27 --- dbgpt/storage/metadata/__init__.py | 7 +- dbgpt/storage/metadata/_base_dao.py | 34 ++-- dbgpt/storage/metadata/db_factory.py | 7 + dbgpt/storage/metadata/db_manager.py | 139 +++++++------ dbgpt/storage/metadata/db_storage.py | 18 +- dbgpt/storage/metadata/meta_data.py | 0 dbgpt/storage/schema.py | 25 ++- dbgpt/storage/vector_store/__init__.py | 1 + dbgpt/storage/vector_store/base.py | 83 ++++---- dbgpt/storage/vector_store/chroma_store.py | 42 ++-- dbgpt/storage/vector_store/connector.py | 85 +++++--- dbgpt/storage/vector_store/milvus_store.py | 73 ++++--- dbgpt/storage/vector_store/pgvector_store.py | 47 +++-- dbgpt/storage/vector_store/weaviate_store.py | 40 ++-- requirements/lint-requirements.txt | 3 +- 50 files changed, 784 insertions(+), 667 deletions(-) create mode 100644 dbgpt/app/openapi/api_v1/editor/_chat_history/__init__.py rename dbgpt/{storage/chat_history => app/openapi/api_v1/editor/_chat_history}/base.py (72%) rename dbgpt/{storage/chat_history => app/openapi/api_v1/editor/_chat_history}/chat_hisotry_factory.py (56%) rename dbgpt/{storage/chat_history/store_type => app/openapi/api_v1/editor/_chat_history}/meta_db_history.py (52%) rename dbgpt/storage/cache/{operator.py => operators.py} (79%) delete mode 100644 dbgpt/storage/cache/protocal/__init__.py create mode 100644 dbgpt/storage/cache/protocol/__init__.py delete mode 100644 dbgpt/storage/chat_history/store_type/__init__.py delete mode 100644 dbgpt/storage/chat_history/store_type/duckdb_history.py delete mode 100644 dbgpt/storage/chat_history/store_type/file_history.py delete mode 100644 dbgpt/storage/chat_history/store_type/mem_history.py delete mode 100644 dbgpt/storage/metadata/meta_data.py diff --git a/.mypy.ini b/.mypy.ini index 55abad8f7..d0578aa6d 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -8,8 +8,8 @@ follow_imports = skip [mypy-dbgpt.datasource.*] follow_imports = skip -[mypy-dbgpt.storage.*] -follow_imports = skip +# [mypy-dbgpt.storage.*] +# follow_imports = skip [mypy-dbgpt.serve.*] follow_imports = skip @@ -57,4 +57,17 @@ ignore_missing_imports = True [mypy-spacy.*] ignore_missing_imports = True -follow_imports = skip \ No newline at end of file +follow_imports = skip + +# Storage +[mypy-msgpack.*] +ignore_missing_imports = True + +[mypy-rocksdict.*] +ignore_missing_imports = True + +[mypy-weaviate.*] +ignore_missing_imports = True + +[mypy-pymilvus.*] +ignore_missing_imports = True diff --git a/Makefile b/Makefile index 2780c15d0..298133819 100644 --- a/Makefile +++ b/Makefile @@ -50,6 +50,7 @@ fmt: setup ## Format Python code # https://flake8.pycqa.org/en/latest/ $(VENV_BIN)/flake8 dbgpt/core/ $(VENV_BIN)/flake8 dbgpt/rag/ + $(VENV_BIN)/flake8 dbgpt/storage/ # TODO: More package checks with flake8. .PHONY: fmt-check @@ -60,8 +61,7 @@ fmt-check: setup ## Check Python code formatting and style without making change $(VENV_BIN)/blackdoc --check dbgpt examples $(VENV_BIN)/flake8 dbgpt/core/ $(VENV_BIN)/flake8 dbgpt/rag/ - # $(VENV_BIN)/blackdoc --check dbgpt examples - # $(VENV_BIN)/flake8 dbgpt/core/ + $(VENV_BIN)/flake8 dbgpt/storage/ .PHONY: pre-commit pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing @@ -77,8 +77,10 @@ test-doc: $(VENV)/.testenv ## Run doctests .PHONY: mypy mypy: $(VENV)/.testenv ## Run mypy checks # https://github.com/python/mypy - $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/ $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ + # rag depends on core and storage, so we not need to check it again. + # $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/ + # $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/ # TODO: More package checks with mypy. .PHONY: coverage diff --git a/assets/schema/dbgpt.sql b/assets/schema/dbgpt.sql index 1b8315402..86f7ed740 100644 --- a/assets/schema/dbgpt.sql +++ b/assets/schema/dbgpt.sql @@ -83,7 +83,7 @@ CREATE TABLE IF NOT EXISTS `chat_history` `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', - `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', + `summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', diff --git a/dbgpt/app/openapi/api_v1/editor/_chat_history/__init__.py b/dbgpt/app/openapi/api_v1/editor/_chat_history/__init__.py new file mode 100644 index 000000000..797d4f804 --- /dev/null +++ b/dbgpt/app/openapi/api_v1/editor/_chat_history/__init__.py @@ -0,0 +1,4 @@ +"""Old chat history module. + +Just used by editor. +""" diff --git a/dbgpt/storage/chat_history/base.py b/dbgpt/app/openapi/api_v1/editor/_chat_history/base.py similarity index 72% rename from dbgpt/storage/chat_history/base.py rename to dbgpt/app/openapi/api_v1/editor/_chat_history/base.py index 2f83f0c5d..2f1c32678 100644 --- a/dbgpt/storage/chat_history/base.py +++ b/dbgpt/app/openapi/api_v1/editor/_chat_history/base.py @@ -8,10 +8,10 @@ from dbgpt.core.interface.message import OnceConversation class MemoryStoreType(Enum): - File = "file" - Memory = "memory" + # File = "file" + # Memory = "memory" DB = "db" - DuckDb = "duckdb" + # DuckDb = "duckdb" class BaseChatHistoryMemory(ABC): @@ -24,18 +24,14 @@ class BaseChatHistoryMemory(ABC): def messages(self) -> List[OnceConversation]: # type: ignore """Retrieve the messages from the local file""" - @abstractmethod - def create(self, user_name: str) -> None: - """Append the message to the record in the local file""" + # @abstractmethod + # def create(self, user_name: str) -> None: + # """Append the message to the record in the local file""" @abstractmethod def append(self, message: OnceConversation) -> None: """Append the message to the record in the local file""" - # @abstractmethod - # def clear(self) -> None: - # """Clear session memory from the local file""" - @abstractmethod def update(self, messages: List[OnceConversation]) -> None: pass @@ -45,14 +41,11 @@ class BaseChatHistoryMemory(ABC): pass @abstractmethod - def conv_info(self, conv_uid: Optional[str] = None) -> None: - pass - - @abstractmethod - def get_messages(self) -> List[OnceConversation]: + def get_messages(self) -> List[Dict]: pass @staticmethod + @abstractmethod def conv_list( user_name: Optional[str] = None, sys_code: Optional[str] = None ) -> List[Dict]: diff --git a/dbgpt/storage/chat_history/chat_hisotry_factory.py b/dbgpt/app/openapi/api_v1/editor/_chat_history/chat_hisotry_factory.py similarity index 56% rename from dbgpt/storage/chat_history/chat_hisotry_factory.py rename to dbgpt/app/openapi/api_v1/editor/_chat_history/chat_hisotry_factory.py index 9a556fe5c..249f30a6e 100644 --- a/dbgpt/storage/chat_history/chat_hisotry_factory.py +++ b/dbgpt/app/openapi/api_v1/editor/_chat_history/chat_hisotry_factory.py @@ -1,13 +1,16 @@ +"""Module for chat history factory. + +It will remove in the future, just support db store type now. +""" import logging from typing import Type from dbgpt._private.config import Config -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory -from .base import MemoryStoreType +from .base import BaseChatHistoryMemory, MemoryStoreType # Import first for auto create table -from .store_type.meta_db_history import DbHistoryMemory +from .meta_db_history import DbHistoryMemory # TODO remove global variable CFG = Config() @@ -20,13 +23,6 @@ class ChatHistory: self.mem_store_class_map = {} # Just support db store type after v0.4.6 - # from .store_type.duckdb_history import DuckdbHistoryMemory - # from .store_type.file_history import FileHistoryMemory - # from .store_type.mem_history import MemHistoryMemory - # self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory - # self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory - # self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory - self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory: @@ -53,24 +49,26 @@ class ChatHistory: Raises: ValueError: Invalid store type """ - from .store_type.duckdb_history import DuckdbHistoryMemory - from .store_type.file_history import FileHistoryMemory - from .store_type.mem_history import MemHistoryMemory - - if store_type == MemHistoryMemory.store_type: + if store_type == "memory": logger.error( "Not support memory store type, just support db store type now" ) raise ValueError(f"Invalid store type: {store_type}") - if store_type == FileHistoryMemory.store_type: + if store_type == "file": logger.error("Not support file store type, just support db store type now") raise ValueError(f"Invalid store type: {store_type}") - if store_type == DuckdbHistoryMemory.store_type: - link1 = "https://docs.dbgpt.site/docs/faq/install#q6-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-sqlitel" - link2 = "https://docs.dbgpt.site/docs/faq/install#q7-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-mysql" + if store_type == "duckdb": + link1 = ( + "https://docs.dbgpt.site/docs/latest/faq/install#q6-how-to-migrate-meta" + "-table-chat_history-and-connect_config-from-duckdb-to-sqlite" + ) + link2 = ( + "https://docs.dbgpt.site/docs/latest/faq/install/#q7-how-to-migrate-" + "meta-table-chat_history-and-connect_config-from-duckdb-to-mysql" + ) logger.error( - "Not support duckdb store type after v0.4.6, just support db store type now, " - f"you can migrate your message according to {link1} or {link2}" + "Not support duckdb store type after v0.4.6, just support db store " + f"type now, you can migrate your message according to {link1} or {link2}" ) raise ValueError(f"Invalid store type: {store_type}") diff --git a/dbgpt/storage/chat_history/store_type/meta_db_history.py b/dbgpt/app/openapi/api_v1/editor/_chat_history/meta_db_history.py similarity index 52% rename from dbgpt/storage/chat_history/store_type/meta_db_history.py rename to dbgpt/app/openapi/api_v1/editor/_chat_history/meta_db_history.py index ca08c69fd..0ac1dc4be 100644 --- a/dbgpt/storage/chat_history/store_type/meta_db_history.py +++ b/dbgpt/app/openapi/api_v1/editor/_chat_history/meta_db_history.py @@ -4,64 +4,76 @@ from typing import Dict, List, Optional from dbgpt._private.config import Config from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType from dbgpt.storage.chat_history.chat_history_db import ChatHistoryDao, ChatHistoryEntity +from .base import BaseChatHistoryMemory, MemoryStoreType + CFG = Config() logger = logging.getLogger(__name__) class DbHistoryMemory(BaseChatHistoryMemory): - store_type: str = MemoryStoreType.DB.value + """Db history memory storage. + + It is deprecated. + """ + + store_type: str = MemoryStoreType.DB.value # type: ignore def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id self.chat_history_dao = ChatHistoryDao() def messages(self) -> List[OnceConversation]: - chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid( + chat_history: Optional[ChatHistoryEntity] = self.chat_history_dao.get_by_uid( self.chat_seesion_id ) if chat_history: context = chat_history.messages if context: - conversations: List[OnceConversation] = json.loads(context) + conversations: List[OnceConversation] = json.loads( + context # type: ignore + ) return conversations return [] - def create(self, chat_mode, summary: str, user_name: str) -> None: - try: - chat_history: ChatHistoryEntity = ChatHistoryEntity() - chat_history.chat_mode = chat_mode - chat_history.summary = summary - chat_history.user_name = user_name - - self.chat_history_dao.raw_update(chat_history) - except Exception as e: - logger.error("init create conversation log error!" + str(e)) - + # def create(self, chat_mode, summary: str, user_name: str) -> None: + # try: + # chat_history: ChatHistoryEntity = ChatHistoryEntity() + # chat_history.chat_mode = chat_mode + # chat_history.summary = summary + # chat_history.user_name = user_name + # + # self.chat_history_dao.raw_update(chat_history) + # except Exception as e: + # logger.error("init create conversation log error!" + str(e)) + # def append(self, once_message: OnceConversation) -> None: logger.debug(f"db history append: {once_message}") - chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid( + chat_history: Optional[ChatHistoryEntity] = self.chat_history_dao.get_by_uid( self.chat_seesion_id ) - conversations: List[OnceConversation] = [] + conversations: List[Dict] = [] + latest_user_message = once_message.get_latest_user_message() + summary = latest_user_message.content if latest_user_message else "" if chat_history: context = chat_history.messages if context: - conversations = json.loads(context) + conversations = json.loads(context) # type: ignore else: - chat_history.summary = once_message.get_latest_user_message().content + chat_history.summary = summary # type: ignore else: - chat_history: ChatHistoryEntity = ChatHistoryEntity() - chat_history.conv_uid = self.chat_seesion_id - chat_history.chat_mode = once_message.chat_mode - chat_history.user_name = once_message.user_name - chat_history.sys_code = once_message.sys_code - chat_history.summary = once_message.get_latest_user_message().content + chat_history = ChatHistoryEntity() + chat_history.conv_uid = self.chat_seesion_id # type: ignore + chat_history.chat_mode = once_message.chat_mode # type: ignore + chat_history.user_name = once_message.user_name # type: ignore + chat_history.sys_code = once_message.sys_code # type: ignore + chat_history.summary = summary # type: ignore conversations.append(_conversation_to_dict(once_message)) - chat_history.messages = json.dumps(conversations, ensure_ascii=False) + chat_history.messages = json.dumps( # type: ignore + conversations, ensure_ascii=False + ) self.chat_history_dao.raw_update(chat_history) @@ -72,18 +84,13 @@ class DbHistoryMemory(BaseChatHistoryMemory): def delete(self) -> bool: self.chat_history_dao.raw_delete(self.chat_seesion_id) + return True - def conv_info(self, conv_uid: str = None) -> None: - logger.info("conv_info:{}", conv_uid) - chat_history = self.chat_history_dao.get_by_uid(conv_uid) - return chat_history.__dict__ - - def get_messages(self) -> List[OnceConversation]: - # logger.info("get_messages:{}", self.chat_seesion_id) + def get_messages(self) -> List[Dict]: chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id) if chat_history: context = chat_history.messages - return json.loads(context) + return json.loads(context) # type: ignore return [] @staticmethod diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py index 1e1ec9662..5421dd4e0 100644 --- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py +++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import List +from typing import Dict, List from fastapi import APIRouter, Body, Depends @@ -23,9 +23,9 @@ from dbgpt.app.openapi.editor_view_model import ( ) from dbgpt.app.scene import ChatFactory from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader -from dbgpt.core.interface.message import OnceConversation from dbgpt.serve.conversation.serve import Serve as ConversationServe -from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory + +from ._chat_history.chat_hisotry_factory import ChatHistory router = APIRouter() CFG = Config() @@ -201,7 +201,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()) chat_history_fac = ChatHistory() history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid) - history_messages: List[OnceConversation] = history_mem.get_messages() + history_messages: List[Dict] = history_mem.get_messages() if history_messages: dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name) diff --git a/dbgpt/app/scene/operators/app_operator.py b/dbgpt/app/scene/operators/app_operator.py index c4fc223f7..17d21d370 100644 --- a/dbgpt/app/scene/operators/app_operator.py +++ b/dbgpt/app/scene/operators/app_operator.py @@ -21,10 +21,9 @@ from dbgpt.core.awel import ( from dbgpt.core.operators import ( BufferedConversationMapperOperator, HistoryPromptBuilderOperator, - LLMBranchOperator, ) from dbgpt.model.operators import LLMOperator, StreamingLLMOperator -from dbgpt.storage.cache.operator import ( +from dbgpt.storage.cache.operators import ( CachedModelOperator, CachedModelStreamOperator, CacheManager, diff --git a/dbgpt/core/awel/trigger/iterator_trigger.py b/dbgpt/core/awel/trigger/iterator_trigger.py index 15cbab637..ce6907b3e 100644 --- a/dbgpt/core/awel/trigger/iterator_trigger.py +++ b/dbgpt/core/awel/trigger/iterator_trigger.py @@ -33,7 +33,7 @@ async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIter yield iter_data -class IteratorTrigger(Trigger): +class IteratorTrigger(Trigger[List[Tuple[Any, Any]]]): """Trigger for iterator data. Trigger the dag with iterator data. @@ -46,6 +46,7 @@ class IteratorTrigger(Trigger): data: IterDataType, parallel_num: int = 1, streaming_call: bool = False, + show_progress: bool = True, **kwargs ): """Create a IteratorTrigger. @@ -60,6 +61,7 @@ class IteratorTrigger(Trigger): self._iter_data = data self._parallel_num = parallel_num self._streaming_call = streaming_call + self._show_progress = show_progress super().__init__(**kwargs) async def trigger( @@ -132,17 +134,27 @@ class IteratorTrigger(Trigger): async def call_stream(call_data: Any): async for out in await end_node.call_stream(call_data): yield out + await dag._after_dag_end() - async def run_node(call_data: Any): + async def run_node(call_data: Any) -> Tuple[Any, Any]: async with semaphore: if streaming_call: task_output = call_stream(call_data) else: task_output = await end_node.call(call_data) + await dag._after_dag_end() return call_data, task_output tasks = [] + + if self._show_progress: + from tqdm.asyncio import tqdm_asyncio + + async_module = tqdm_asyncio + else: + async_module = asyncio # type: ignore + async for data in _to_async_iterator(self._iter_data, task_id): tasks.append(run_node(data)) - results = await asyncio.gather(*tasks) + results: List[Tuple[Any, Any]] = await async_module.gather(*tasks) return results diff --git a/dbgpt/model/cluster/embedding/loader.py b/dbgpt/model/cluster/embedding/loader.py index 2bb1611b5..5ce7d6c28 100644 --- a/dbgpt/model/cluster/embedding/loader.py +++ b/dbgpt/model/cluster/embedding/loader.py @@ -10,7 +10,7 @@ from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer if TYPE_CHECKING: from langchain.embeddings.base import Embeddings as LangChainEmbeddings - from dbgpt.rag.embedding import Embeddings + from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings class EmbeddingLoader: @@ -47,7 +47,7 @@ class EmbeddingLoader: openapi_param["model_name"] = proxy_param.proxy_backend return OpenAPIEmbeddings(**openapi_param) else: - from langchain.embeddings import HuggingFaceEmbeddings + from dbgpt.rag.embedding import HuggingFaceEmbeddings kwargs = param.build_kwargs(model_name=param.model_path) return HuggingFaceEmbeddings(**kwargs) diff --git a/dbgpt/model/cluster/embedding/remote_embedding.py b/dbgpt/model/cluster/embedding/remote_embedding.py index d45c9dd85..56ede440a 100644 --- a/dbgpt/model/cluster/embedding/remote_embedding.py +++ b/dbgpt/model/cluster/embedding/remote_embedding.py @@ -1,7 +1,6 @@ from typing import List -from langchain.embeddings.base import Embeddings - +from dbgpt.core import Embeddings from dbgpt.model.cluster.manager_base import WorkerManager diff --git a/dbgpt/model/cluster/worker/embedding_worker.py b/dbgpt/model/cluster/worker/embedding_worker.py index eb8cacd90..3b2aff0fd 100644 --- a/dbgpt/model/cluster/worker/embedding_worker.py +++ b/dbgpt/model/cluster/worker/embedding_worker.py @@ -20,14 +20,8 @@ logger = logging.getLogger(__name__) class EmbeddingsModelWorker(ModelWorker): def __init__(self) -> None: - try: - from langchain.embeddings import HuggingFaceEmbeddings - from langchain.embeddings.base import Embeddings - except ImportError as exc: - raise ImportError( - "Could not import langchain.embeddings.HuggingFaceEmbeddings python package. " - "Please install it with `pip install langchain`." - ) from exc + from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings + self._embeddings_impl: Embeddings = None self._model_params = None self.model_name = None diff --git a/dbgpt/rag/chunk_manager.py b/dbgpt/rag/chunk_manager.py index 094876a66..be43ce96c 100644 --- a/dbgpt/rag/chunk_manager.py +++ b/dbgpt/rag/chunk_manager.py @@ -3,8 +3,7 @@ from enum import Enum from typing import Any, List, Optional -from pydantic import BaseModel, Field - +from dbgpt._private.pydantic import BaseModel, Field from dbgpt.rag.chunk import Chunk, Document from dbgpt.rag.extractor.base import Extractor from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 0758d1e71..6deeb8985 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -411,6 +411,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): end_node = cast(BaseOperator, leaf_nodes[0]) async for output in _chat_with_dag_task(end_node, request, incremental): yield output + await dag._after_dag_end() def _parse_flow_category(self, dag: DAG) -> FlowCategory: """Parse the flow category diff --git a/dbgpt/storage/__init__.py b/dbgpt/storage/__init__.py index e69de29bb..a0ff32750 100644 --- a/dbgpt/storage/__init__.py +++ b/dbgpt/storage/__init__.py @@ -0,0 +1,3 @@ +"""Module of storage.""" + +from .schema import DBType # noqa: F401 diff --git a/dbgpt/storage/cache/__init__.py b/dbgpt/storage/cache/__init__.py index 80b23bf25..3cee3f63f 100644 --- a/dbgpt/storage/cache/__init__.py +++ b/dbgpt/storage/cache/__init__.py @@ -1,6 +1,7 @@ -from dbgpt.storage.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue -from dbgpt.storage.cache.manager import CacheManager, initialize_cache -from dbgpt.storage.cache.storage.base import MemoryCacheStorage +"""Module for cache storage.""" +from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue # noqa: F401 +from .manager import CacheManager, initialize_cache # noqa: F401 +from .storage.base import MemoryCacheStorage # noqa: F401 __all__ = [ "LLMCacheKey", diff --git a/dbgpt/storage/cache/embedding_cache.py b/dbgpt/storage/cache/embedding_cache.py index e69de29bb..46cc0936d 100644 --- a/dbgpt/storage/cache/embedding_cache.py +++ b/dbgpt/storage/cache/embedding_cache.py @@ -0,0 +1 @@ +"""Embeddings cache.""" diff --git a/dbgpt/storage/cache/llm_cache.py b/dbgpt/storage/cache/llm_cache.py index 441c67a6c..a7d7ec39e 100644 --- a/dbgpt/storage/cache/llm_cache.py +++ b/dbgpt/storage/cache/llm_cache.py @@ -1,21 +1,26 @@ +"""Cache client for LLM.""" + import hashlib from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast -from dbgpt.core import ModelOutput, Serializer +from dbgpt.core import ModelOutput from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue -from dbgpt.model.base import ModelType -from dbgpt.storage.cache.manager import CacheManager + +from .manager import CacheManager @dataclass class LLMCacheKeyData: + """Cache key data for LLM.""" + prompt: str model_name: str temperature: Optional[float] = 0.7 max_new_tokens: Optional[int] = None top_p: Optional[float] = 1.0 - model_type: Optional[str] = ModelType.HF + # See dbgpt.model.base.ModelType + model_type: Optional[str] = "huggingface" CacheOutputType = Union[ModelOutput, List[ModelOutput]] @@ -23,12 +28,15 @@ CacheOutputType = Union[ModelOutput, List[ModelOutput]] @dataclass class LLMCacheValueData: + """Cache value data for LLM.""" + output: CacheOutputType user: Optional[str] = None - _is_list: Optional[bool] = False + _is_list: bool = False @staticmethod def from_dict(**kwargs) -> "LLMCacheValueData": + """Create LLMCacheValueData object from dict.""" output = kwargs.get("output") if not output: raise ValueError("Can't new LLMCacheValueData object, output is None") @@ -46,6 +54,7 @@ class LLMCacheValueData: return LLMCacheValueData(**kwargs) def to_dict(self) -> Dict: + """Convert to dict.""" output = self.output is_list = False if isinstance(output, list): @@ -53,16 +62,18 @@ class LLMCacheValueData: is_list = True for out in output: output_list.append(out.to_dict()) - output = output_list + output = output_list # type: ignore else: - output = output.to_dict() + output = output.to_dict() # type: ignore return {"output": output, "_is_list": is_list, "user": self.user} @property def is_list(self) -> bool: + """Return whether the output is a list.""" return self._is_list def __str__(self) -> str: + """Return string representation.""" if not isinstance(self.output, list): return f"user: {self.user}, output: {self.output}" else: @@ -70,74 +81,116 @@ class LLMCacheValueData: class LLMCacheKey(CacheKey[LLMCacheKeyData]): + """Cache key for LLM.""" + def __init__(self, **kwargs) -> None: + """Create a new instance of LLMCacheKey.""" super().__init__() self.config = LLMCacheKeyData(**kwargs) def __hash__(self) -> int: + """Return the hash value of the object.""" serialize_bytes = self.serialize() return int(hashlib.sha256(serialize_bytes).hexdigest(), 16) def __eq__(self, other: Any) -> bool: + """Check equality with another key.""" if not isinstance(other, LLMCacheKey): return False return self.config == other.config def get_hash_bytes(self) -> bytes: + """Return the byte array of hash value. + + Returns: + bytes: The byte array of hash value. + """ serialize_bytes = self.serialize() return hashlib.sha256(serialize_bytes).digest() def to_dict(self) -> Dict: + """Convert to dict.""" return asdict(self.config) def get_value(self) -> LLMCacheKeyData: + """Return the real object of current cache key.""" return self.config class LLMCacheValue(CacheValue[LLMCacheValueData]): + """Cache value for LLM.""" + def __init__(self, **kwargs) -> None: + """Create a new instance of LLMCacheValue.""" super().__init__() self.value = LLMCacheValueData.from_dict(**kwargs) def to_dict(self) -> Dict: + """Convert to dict.""" return self.value.to_dict() def get_value(self) -> LLMCacheValueData: + """Return the underlying real value.""" return self.value def __str__(self) -> str: + """Return string representation.""" return f"value: {str(self.value)}" class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]): + """Cache client for LLM.""" + def __init__(self, cache_manager: CacheManager) -> None: + """Create a new instance of LLMCacheClient.""" super().__init__() self._cache_manager: CacheManager = cache_manager async def get( - self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None + self, + key: LLMCacheKey, # type: ignore + cache_config: Optional[CacheConfig] = None, ) -> Optional[LLMCacheValue]: - return await self._cache_manager.get(key, LLMCacheValue, cache_config) + """Retrieve a value from the cache using the provided key. + + Args: + key (LLMCacheKey): The key to get cache + cache_config (Optional[CacheConfig]): Cache config + + Returns: + Optional[LLMCacheValue]: The value retrieved according to key. If cache key + not exist, return None. + """ + return cast( + LLMCacheValue, + await self._cache_manager.get(key, LLMCacheValue, cache_config), + ) async def set( self, - key: LLMCacheKey, - value: LLMCacheValue, + key: LLMCacheKey, # type: ignore + value: LLMCacheValue, # type: ignore cache_config: Optional[CacheConfig] = None, ) -> None: + """Set a value in the cache for the provided key.""" return await self._cache_manager.set(key, value, cache_config) async def exists( - self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None + self, + key: LLMCacheKey, # type: ignore + cache_config: Optional[CacheConfig] = None, ) -> bool: + """Check if a key exists in the cache.""" return await self.get(key, cache_config) is not None - def new_key(self, **kwargs) -> LLMCacheKey: + def new_key(self, **kwargs) -> LLMCacheKey: # type: ignore + """Create a cache key with params.""" key = LLMCacheKey(**kwargs) key.set_serializer(self._cache_manager.serializer) return key - def new_value(self, **kwargs) -> LLMCacheValue: + def new_value(self, **kwargs) -> LLMCacheValue: # type: ignore + """Create a cache value with params.""" value = LLMCacheValue(**kwargs) value.set_serializer(self._cache_manager.serializer) return value diff --git a/dbgpt/storage/cache/manager.py b/dbgpt/storage/cache/manager.py index 97063c347..586518346 100644 --- a/dbgpt/storage/cache/manager.py +++ b/dbgpt/storage/cache/manager.py @@ -1,24 +1,31 @@ +"""Cache manager.""" + import logging from abc import ABC, abstractmethod from concurrent.futures import Executor -from typing import Optional, Type +from typing import Optional, Type, cast from dbgpt.component import BaseComponent, ComponentType, SystemApp from dbgpt.core import CacheConfig, CacheKey, CacheValue, Serializable, Serializer from dbgpt.core.interface.cache import K, V -from dbgpt.storage.cache.storage.base import CacheStorage from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async +from .storage.base import CacheStorage + logger = logging.getLogger(__name__) class CacheManager(BaseComponent, ABC): + """The cache manager interface.""" + name = ComponentType.MODEL_CACHE_MANAGER def __init__(self, system_app: SystemApp | None = None): + """Create cache manager.""" super().__init__(system_app) def init_app(self, system_app: SystemApp): + """Initialize cache manager.""" self.system_app = system_app @abstractmethod @@ -28,7 +35,7 @@ class CacheManager(BaseComponent, ABC): value: CacheValue[V], cache_config: Optional[CacheConfig] = None, ): - """Set cache""" + """Set cache with key.""" @abstractmethod async def get( @@ -36,27 +43,30 @@ class CacheManager(BaseComponent, ABC): key: CacheKey[K], cls: Type[Serializable], cache_config: Optional[CacheConfig] = None, - ) -> CacheValue[V]: - """Get cache with key""" + ) -> Optional[CacheValue[V]]: + """Retrieve cache with key.""" @property @abstractmethod def serializer(self) -> Serializer: - """Get cache serializer""" + """Return serializer to serialize/deserialize cache value.""" class LocalCacheManager(CacheManager): + """Local cache manager.""" + def __init__( self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage ) -> None: + """Create local cache manager.""" super().__init__(system_app) self._serializer = serializer self._storage = storage @property def executor(self) -> Executor: - """Return executor to submit task""" - self._executor = self.system_app.get_component( + """Return executor.""" + return self.system_app.get_component( # type: ignore ComponentType.EXECUTOR_DEFAULT, ExecutorFactory ).create() @@ -66,6 +76,7 @@ class LocalCacheManager(CacheManager): value: CacheValue[V], cache_config: Optional[CacheConfig] = None, ): + """Set cache with key.""" if self._storage.support_async(): await self._storage.aset(key, value, cache_config) else: @@ -78,7 +89,8 @@ class LocalCacheManager(CacheManager): key: CacheKey[K], cls: Type[Serializable], cache_config: Optional[CacheConfig] = None, - ) -> CacheValue[V]: + ) -> Optional[CacheValue[V]]: + """Retrieve cache with key.""" if self._storage.support_async(): item_bytes = await self._storage.aget(key, cache_config) else: @@ -87,30 +99,42 @@ class LocalCacheManager(CacheManager): ) if not item_bytes: return None - return self._serializer.deserialize(item_bytes.value_data, cls) + return cast( + CacheValue[V], self._serializer.deserialize(item_bytes.value_data, cls) + ) @property def serializer(self) -> Serializer: + """Return serializer to serialize/deserialize cache value.""" return self._serializer def initialize_cache( system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str ): - from dbgpt.storage.cache.storage.base import MemoryCacheStorage + """Initialize cache manager. + + Args: + system_app (SystemApp): The system app. + storage_type (str): The storage type. + max_memory_mb (int): The max memory in MB. + persist_dir (str): The persist directory. + """ from dbgpt.util.serialization.json_serialization import JsonSerializer - cache_storage = None + from .storage.base import MemoryCacheStorage + if storage_type == "disk": try: - from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage + from .storage.disk.disk_storage import DiskCacheStorage - cache_storage = DiskCacheStorage( + cache_storage: CacheStorage = DiskCacheStorage( persist_dir, mem_table_buffer_mb=max_memory_mb ) except ImportError as e: logger.warn( - f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}" + f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error " + f"message: {str(e)}" ) cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb) else: diff --git a/dbgpt/storage/cache/operator.py b/dbgpt/storage/cache/operators.py similarity index 79% rename from dbgpt/storage/cache/operator.py rename to dbgpt/storage/cache/operators.py index 40260cb54..df33b700d 100644 --- a/dbgpt/storage/cache/operator.py +++ b/dbgpt/storage/cache/operators.py @@ -1,5 +1,6 @@ +"""Operators for processing model outputs with caching support.""" import logging -from typing import AsyncIterator, Dict, List, Union +from typing import AsyncIterator, Dict, List, Optional, Union, cast from dbgpt.core import ModelOutput, ModelRequest from dbgpt.core.awel import ( @@ -10,7 +11,9 @@ from dbgpt.core.awel import ( StreamifyAbsOperator, TransformStreamAbsOperator, ) -from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue + +from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue +from .manager import CacheManager logger = logging.getLogger(__name__) @@ -26,15 +29,17 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput]) **kwargs: Additional keyword arguments. Methods: - streamify: Processes a stream of inputs with cache support, yielding model outputs. + streamify: Processes a stream of inputs with cache support, yielding model + outputs. """ def __init__(self, cache_manager: CacheManager, **kwargs) -> None: + """Create a new instance of CachedModelStreamOperator.""" super().__init__(**kwargs) self._cache_manager = cache_manager self._client = LLMCacheClient(cache_manager) - async def streamify(self, input_value: ModelRequest) -> AsyncIterator[ModelOutput]: + async def streamify(self, input_value: ModelRequest): """Process inputs as a stream with cache support and yield model outputs. Args: @@ -45,10 +50,13 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput]) """ cache_dict = _parse_cache_key_dict(input_value) llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict) - llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key) + llm_cache_value = await self._client.get(llm_cache_key) logger.info(f"llm_cache_value: {llm_cache_value}") - for out in llm_cache_value.get_value().output: - yield out + if not llm_cache_value: + raise ValueError(f"Cache value not found for key: {llm_cache_key}") + outputs = cast(List[ModelOutput], llm_cache_value.get_value().output) + for out in outputs: + yield cast(ModelOutput, out) class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]): @@ -63,6 +71,7 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]): """ def __init__(self, cache_manager: CacheManager, **kwargs) -> None: + """Create a new instance of CachedModelOperator.""" super().__init__(**kwargs) self._cache_manager = cache_manager self._client = LLMCacheClient(cache_manager) @@ -78,14 +87,18 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]): """ cache_dict = _parse_cache_key_dict(input_value) llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict) - llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key) + llm_cache_value = await self._client.get(llm_cache_key) + if not llm_cache_value: + raise ValueError(f"Cache value not found for key: {llm_cache_key}") logger.info(f"llm_cache_value: {llm_cache_value}") - return llm_cache_value.get_value().output + return cast(ModelOutput, llm_cache_value.get_value().output) class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]): - """ - A branch operator that decides whether to use cached data or to process data using the model. + """Branch operator for model processing with cache support. + + A branch operator that decides whether to use cached data or to process data using + the model. Args: cache_manager (CacheManager): The cache manager for managing cache operations. @@ -101,6 +114,7 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]): cache_task_name: str, **kwargs, ): + """Create a new instance of ModelCacheBranchOperator.""" super().__init__(branches=None, **kwargs) self._cache_manager = cache_manager self._client = LLMCacheClient(cache_manager) @@ -110,10 +124,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]): async def branches( self, ) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]: - """Defines branch logic based on cache availability. + """Branch logic based on cache availability. + + Defines branch logic based on cache availability. Returns: - Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names. + Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping + branch functions to task names. """ async def check_cache_true(input_value: ModelRequest) -> bool: @@ -124,12 +141,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]): cache_key: LLMCacheKey = self._client.new_key(**cache_dict) cache_value = await self._client.get(cache_key) logger.debug( - f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}" + f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: " + f"{cache_value}" ) await self.current_dag_context.save_to_share_data( _LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True ) - return True if cache_value else False + return bool(cache_value) async def check_cache_false(input_value: ModelRequest): # Inverse of check_cache_true @@ -152,22 +170,25 @@ class ModelStreamSaveCacheOperator( """ def __init__(self, cache_manager: CacheManager, **kwargs): + """Create a new instance of ModelStreamSaveCacheOperator.""" self._cache_manager = cache_manager self._client = LLMCacheClient(cache_manager) super().__init__(**kwargs) - async def transform_stream( - self, input_value: AsyncIterator[ModelOutput] - ) -> AsyncIterator[ModelOutput]: - """Transforms the input stream by saving the outputs to cache. + async def transform_stream(self, input_value: AsyncIterator[ModelOutput]): + """Save the stream of model outputs to cache. + + Transforms the input stream by saving the outputs to cache. Args: - input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs. + input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model + outputs. Returns: - AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache. + AsyncIterator[ModelOutput]: The same input iterator, but the outputs are + saved to cache. """ - llm_cache_key: LLMCacheKey = None + llm_cache_key: Optional[LLMCacheKey] = None outputs = [] async for out in input_value: if not llm_cache_key: @@ -190,12 +211,13 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]): """ def __init__(self, cache_manager: CacheManager, **kwargs): + """Create a new instance of ModelSaveCacheOperator.""" self._cache_manager = cache_manager self._client = LLMCacheClient(cache_manager) super().__init__(**kwargs) async def map(self, input_value: ModelOutput) -> ModelOutput: - """Saves a single model output to cache and returns it. + """Save model output to cache. Args: input_value (ModelOutput): The output from the model to be cached. @@ -213,7 +235,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]): def _parse_cache_key_dict(input_value: ModelRequest) -> Dict: - """Parses and extracts relevant fields from input to form a cache key dictionary. + """Parse and extract relevant fields from input to form a cache key dictionary. Args: input_value (Dict): The input dictionary containing model and prompt parameters. diff --git a/dbgpt/storage/cache/protocal/__init__.py b/dbgpt/storage/cache/protocal/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/dbgpt/storage/cache/protocol/__init__.py b/dbgpt/storage/cache/protocol/__init__.py new file mode 100644 index 000000000..60819756d --- /dev/null +++ b/dbgpt/storage/cache/protocol/__init__.py @@ -0,0 +1 @@ +"""Module for protocol.""" diff --git a/dbgpt/storage/cache/storage/__init__.py b/dbgpt/storage/cache/storage/__init__.py index e69de29bb..cf52f4edc 100644 --- a/dbgpt/storage/cache/storage/__init__.py +++ b/dbgpt/storage/cache/storage/__init__.py @@ -0,0 +1 @@ +"""Module for cache storage implementation.""" diff --git a/dbgpt/storage/cache/storage/base.py b/dbgpt/storage/cache/storage/base.py index 779552b98..675726f75 100644 --- a/dbgpt/storage/cache/storage/base.py +++ b/dbgpt/storage/cache/storage/base.py @@ -1,3 +1,4 @@ +"""Base cache storage class.""" import logging from abc import ABC, abstractmethod from collections import OrderedDict @@ -22,8 +23,7 @@ logger = logging.getLogger(__name__) @dataclass class StorageItem: - """ - A class representing a storage item. + """A class representing a storage item. This class encapsulates data related to a storage item, such as its length, the hash of the key, and the data for both the key and value. @@ -44,6 +44,7 @@ class StorageItem: def build_from( key_hash: bytes, key_data: bytes, value_data: bytes ) -> "StorageItem": + """Build a StorageItem from the provided key and value data.""" length = ( 32 + _get_object_bytes(key_hash) @@ -56,6 +57,7 @@ class StorageItem: @staticmethod def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem": + """Build a StorageItem from the provided key and value.""" key_hash = key.get_hash_bytes() key_data = key.serialize() value_data = value.serialize() @@ -105,6 +107,8 @@ class StorageItem: class CacheStorage(ABC): + """Base class for cache storage.""" + @abstractmethod def check_config( self, @@ -122,6 +126,7 @@ class CacheStorage(ABC): """ def support_async(self) -> bool: + """Check whether the storage support async operation.""" return False @abstractmethod @@ -135,7 +140,8 @@ class CacheStorage(ABC): cache_config (Optional[CacheConfig]): Cache config Returns: - Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None. + Optional[StorageItem]: The storage item retrieved according to key. If + cache key not exist, return None. """ async def aget( @@ -148,7 +154,8 @@ class CacheStorage(ABC): cache_config (Optional[CacheConfig]): Cache config Returns: - Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None. + Optional[StorageItem]: The storage item of bytes retrieved according to + key. If cache key not exist, return None. """ raise NotImplementedError @@ -184,8 +191,11 @@ class CacheStorage(ABC): class MemoryCacheStorage(CacheStorage): + """A simple in-memory cache storage implementation.""" + def __init__(self, max_memory_mb: int = 256): - self.cache = OrderedDict() + """Create a new instance of MemoryCacheStorage.""" + self.cache: OrderedDict = OrderedDict() self.max_memory = max_memory_mb * 1024 * 1024 self.current_memory_usage = 0 @@ -194,6 +204,7 @@ class MemoryCacheStorage(CacheStorage): cache_config: Optional[CacheConfig] = None, raise_error: Optional[bool] = True, ) -> bool: + """Check whether the CacheConfig is legal.""" if ( cache_config and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH @@ -208,10 +219,11 @@ class MemoryCacheStorage(CacheStorage): def get( self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None ) -> Optional[StorageItem]: + """Retrieve a storage item from the cache using the provided key.""" self.check_config(cache_config, raise_error=True) # Exact match retrieval key_hash = hash(key) - item: StorageItem = self.cache.get(key_hash) + item: Optional[StorageItem] = self.cache.get(key_hash) logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}") if not item: @@ -226,6 +238,7 @@ class MemoryCacheStorage(CacheStorage): value: CacheValue[V], cache_config: Optional[CacheConfig] = None, ) -> None: + """Set a value in the cache for the provided key.""" key_hash = hash(key) item = StorageItem.build_from_kv(key, value) # Calculate memory size of the new entry @@ -242,6 +255,7 @@ class MemoryCacheStorage(CacheStorage): def exists( self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None ) -> bool: + """Check if the key exists in the cache.""" return self.get(key, cache_config) is not None def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None): diff --git a/dbgpt/storage/cache/storage/disk/__init__.py b/dbgpt/storage/cache/storage/disk/__init__.py index e69de29bb..0310c77fd 100644 --- a/dbgpt/storage/cache/storage/disk/__init__.py +++ b/dbgpt/storage/cache/storage/disk/__init__.py @@ -0,0 +1 @@ +"""Disk cache storage implementation.""" diff --git a/dbgpt/storage/cache/storage/disk/disk_storage.py b/dbgpt/storage/cache/storage/disk/disk_storage.py index fa1c01003..0cdf5df83 100644 --- a/dbgpt/storage/cache/storage/disk/disk_storage.py +++ b/dbgpt/storage/cache/storage/disk/disk_storage.py @@ -1,3 +1,7 @@ +"""Disk storage for cache. + +Implement the cache storage using rocksdb. +""" import logging from typing import Optional @@ -11,14 +15,14 @@ from dbgpt.core.interface.cache import ( RetrievalPolicy, V, ) -from dbgpt.storage.cache.storage.base import CacheStorage, StorageItem + +from ..base import CacheStorage, StorageItem logger = logging.getLogger(__name__) -def db_options( - mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2 -): +def db_options(mem_table_buffer_mb: int = 256, background_threads: int = 2): + """Create rocksdb options.""" opt = Options() # create table opt.create_if_missing(True) @@ -42,9 +46,10 @@ def db_options( class DiskCacheStorage(CacheStorage): - def __init__( - self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256 - ) -> None: + """Disk cache storage using rocksdb.""" + + def __init__(self, persist_dir: str, mem_table_buffer_mb: int = 256) -> None: + """Create a new instance of DiskCacheStorage.""" super().__init__() self.db: Rdict = Rdict( persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb) @@ -55,6 +60,7 @@ class DiskCacheStorage(CacheStorage): cache_config: Optional[CacheConfig] = None, raise_error: Optional[bool] = True, ) -> bool: + """Check whether the CacheConfig is legal.""" if ( cache_config and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH @@ -69,6 +75,7 @@ class DiskCacheStorage(CacheStorage): def get( self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None ) -> Optional[StorageItem]: + """Retrieve a storage item from the cache using the provided key.""" self.check_config(cache_config, raise_error=True) # Exact match retrieval @@ -86,6 +93,7 @@ class DiskCacheStorage(CacheStorage): value: CacheValue[V], cache_config: Optional[CacheConfig] = None, ) -> None: + """Set a value in the cache for the provided key.""" item = StorageItem.build_from_kv(key, value) key_hash = item.key_hash self.db[key_hash] = item.serialize() diff --git a/dbgpt/storage/cache/storage/tests/test_storage.py b/dbgpt/storage/cache/storage/tests/test_storage.py index 1ba2545f0..159a57fe7 100644 --- a/dbgpt/storage/cache/storage/tests/test_storage.py +++ b/dbgpt/storage/cache/storage/tests/test_storage.py @@ -1,5 +1,3 @@ -import pytest - from dbgpt.util.memory_utils import _get_object_bytes from ..base import StorageItem diff --git a/dbgpt/storage/chat_history/__init__.py b/dbgpt/storage/chat_history/__init__.py index 8b1378917..fec681a2c 100644 --- a/dbgpt/storage/chat_history/__init__.py +++ b/dbgpt/storage/chat_history/__init__.py @@ -1 +1,19 @@ +"""Module of chat history.""" +from .chat_history_db import ( # noqa: F401 + ChatHistoryDao, + ChatHistoryEntity, + ChatHistoryMessageEntity, +) +from .storage_adapter import ( # noqa: F401 + DBMessageStorageItemAdapter, + DBStorageConversationItemAdapter, +) + +__ALL__ = [ + "ChatHistoryEntity", + "ChatHistoryMessageEntity", + "ChatHistoryDao", + "DBStorageConversationItemAdapter", + "DBMessageStorageItemAdapter", +] diff --git a/dbgpt/storage/chat_history/chat_history_db.py b/dbgpt/storage/chat_history/chat_history_db.py index 8faaca8b2..e663e2bba 100644 --- a/dbgpt/storage/chat_history/chat_history_db.py +++ b/dbgpt/storage/chat_history/chat_history_db.py @@ -1,12 +1,15 @@ +"""Chat history database model.""" from datetime import datetime from typing import Optional from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint -from dbgpt.storage.metadata import BaseDao, Model +from ..metadata import BaseDao, Model class ChatHistoryEntity(Model): + """Chat history entity.""" + __tablename__ = "chat_history" __table_args__ = (UniqueConstraint("conv_uid", name="uk_conv_uid"),) id = Column( @@ -14,13 +17,16 @@ class ChatHistoryEntity(Model): ) conv_uid = Column( String(255), - # Change from False to True, the alembic migration will fail, so we use UniqueConstraint to replace it + # Change from False to True, the alembic migration will fail, so we use + # UniqueConstraint to replace it unique=False, nullable=False, comment="Conversation record unique id", ) chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode") - summary = Column(String(255), nullable=False, comment="Conversation record summary") + summary = Column( + Text(length=2**31 - 1), nullable=False, comment="Conversation record summary" + ) user_name = Column(String(255), nullable=True, comment="interlocutor") messages = Column( Text(length=2**31 - 1), nullable=True, comment="Conversation details" @@ -38,6 +44,8 @@ class ChatHistoryEntity(Model): class ChatHistoryMessageEntity(Model): + """Chat history message entity.""" + __tablename__ = "chat_history_message" __table_args__ = ( UniqueConstraint("conv_uid", "index", name="uk_conversation_message"), @@ -61,9 +69,12 @@ class ChatHistoryMessageEntity(Model): class ChatHistoryDao(BaseDao): + """Chat history dao.""" + def list_last_20( self, user_name: Optional[str] = None, sys_code: Optional[str] = None ): + """Retrieve the last 20 chat history records.""" session = self.get_raw_session() chat_history = session.query(ChatHistoryEntity) if user_name: @@ -78,6 +89,7 @@ class ChatHistoryDao(BaseDao): return result def raw_update(self, entity: ChatHistoryEntity): + """Update the chat history record.""" session = self.get_raw_session() try: updated = session.merge(entity) @@ -87,6 +99,7 @@ class ChatHistoryDao(BaseDao): session.close() def update_message_by_uid(self, message: str, conv_uid: str): + """Update the chat history record.""" session = self.get_raw_session() try: chat_history = session.query(ChatHistoryEntity) @@ -97,7 +110,8 @@ class ChatHistoryDao(BaseDao): finally: session.close() - def raw_delete(self, conv_uid: int): + def raw_delete(self, conv_uid: str): + """Delete the chat history record.""" if conv_uid is None: raise Exception("conv_uid is None") with self.session() as session: @@ -106,5 +120,6 @@ class ChatHistoryDao(BaseDao): chat_history.delete() def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]: + """Retrieve the chat history record by conv_uid.""" with self.session(commit=False) as session: return session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid).first() diff --git a/dbgpt/storage/chat_history/storage_adapter.py b/dbgpt/storage/chat_history/storage_adapter.py index 93302ef89..eb2b3412f 100644 --- a/dbgpt/storage/chat_history/storage_adapter.py +++ b/dbgpt/storage/chat_history/storage_adapter.py @@ -1,5 +1,7 @@ +"""Adapter for chat history storage.""" + import json -from typing import Dict, List, Type +from typing import Dict, List, Optional, Type from sqlalchemy.orm import Session @@ -20,16 +22,23 @@ from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity class DBStorageConversationItemAdapter( StorageItemAdapter[StorageConversation, ChatHistoryEntity] ): + """Adapter for chat history storage.""" + def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity: + """Convert to storage format.""" message_ids = ",".join(item.message_ids) messages = None if not item.save_message_independent and item.messages: message_dict_list = [_conversation_to_dict(item)] messages = json.dumps(message_dict_list, ensure_ascii=False) + summary = item.summary + latest_user_message = item.get_latest_user_message() + if not summary and latest_user_message is not None: + summary = latest_user_message.content return ChatHistoryEntity( conv_uid=item.conv_uid, chat_mode=item.chat_mode, - summary=item.summary or item.get_latest_user_message().content, + summary=summary, user_name=item.user_name, # We not save messages to chat_history table in new design messages=messages, @@ -38,25 +47,25 @@ class DBStorageConversationItemAdapter( ) def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation: + """Convert from storage format.""" message_ids = model.message_ids.split(",") if model.message_ids else [] old_conversations: List[Dict] = ( - json.loads(model.messages) if model.messages else [] + json.loads(model.messages) if model.messages else [] # type: ignore ) old_messages = [] save_message_independent = True if old_conversations: - # Load old messages from old conversations, in old design, we save messages to chat_history table + # Load old messages from old conversations, in old design, we save messages + # to chat_history table save_message_independent = False - old_messages: List[BaseMessage] = _parse_old_conversations( - old_conversations - ) + old_messages = _parse_old_conversations(old_conversations) return StorageConversation( - conv_uid=model.conv_uid, - chat_mode=model.chat_mode, - summary=model.summary, - user_name=model.user_name, + conv_uid=model.conv_uid, # type: ignore + chat_mode=model.chat_mode, # type: ignore + summary=model.summary, # type: ignore + user_name=model.user_name, # type: ignore message_ids=message_ids, - sys_code=model.sys_code, + sys_code=model.sys_code, # type: ignore save_message_independent=save_message_independent, messages=old_messages, ) @@ -64,10 +73,11 @@ class DBStorageConversationItemAdapter( def get_query_for_identifier( self, storage_format: Type[ChatHistoryEntity], - resource_id: ConversationIdentifier, + resource_id: ConversationIdentifier, # type: ignore **kwargs, ): - session: Session = kwargs.get("session") + """Get query for identifier.""" + session: Optional[Session] = kwargs.get("session") if session is None: raise Exception("session is None") return session.query(ChatHistoryEntity).filter( @@ -78,7 +88,10 @@ class DBStorageConversationItemAdapter( class DBMessageStorageItemAdapter( StorageItemAdapter[MessageStorageItem, ChatHistoryMessageEntity] ): + """Adapter for chat history message storage.""" + def to_storage_format(self, item: MessageStorageItem) -> ChatHistoryMessageEntity: + """Convert to storage format.""" round_index = item.message_detail.get("round_index", 0) message_detail = json.dumps(item.message_detail, ensure_ascii=False) return ChatHistoryMessageEntity( @@ -91,22 +104,26 @@ class DBMessageStorageItemAdapter( def from_storage_format( self, model: ChatHistoryMessageEntity ) -> MessageStorageItem: + """Convert from storage format.""" message_detail = ( - json.loads(model.message_detail) if model.message_detail else {} + json.loads(model.message_detail) # type: ignore + if model.message_detail + else {} ) return MessageStorageItem( - conv_uid=model.conv_uid, - index=model.index, + conv_uid=model.conv_uid, # type: ignore + index=model.index, # type: ignore message_detail=message_detail, ) def get_query_for_identifier( self, storage_format: Type[ChatHistoryMessageEntity], - resource_id: MessageIdentifier, + resource_id: MessageIdentifier, # type: ignore **kwargs, ): - session: Session = kwargs.get("session") + """Get query for identifier.""" + session: Optional[Session] = kwargs.get("session") if session is None: raise Exception("session is None") return session.query(ChatHistoryMessageEntity).filter( diff --git a/dbgpt/storage/chat_history/store_type/__init__.py b/dbgpt/storage/chat_history/store_type/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/dbgpt/storage/chat_history/store_type/duckdb_history.py b/dbgpt/storage/chat_history/store_type/duckdb_history.py deleted file mode 100644 index 81644f346..000000000 --- a/dbgpt/storage/chat_history/store_type/duckdb_history.py +++ /dev/null @@ -1,182 +0,0 @@ -import json -import os -from typing import Dict, List, Optional - -import duckdb - -from dbgpt._private.config import Config -from dbgpt.configs.model_config import PILOT_PATH -from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory - -from ..base import MemoryStoreType - -default_db_path = os.path.join(PILOT_PATH, "message") -duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") -table_name = "chat_history" - -CFG = Config() - - -class DuckdbHistoryMemory(BaseChatHistoryMemory): - store_type: str = MemoryStoreType.DuckDb.value - - def __init__(self, chat_session_id: str): - self.chat_seesion_id = chat_session_id - os.makedirs(default_db_path, exist_ok=True) - self.connect = duckdb.connect(duckdb_path) - self.__init_chat_history_tables() - - def __init_chat_history_tables(self): - # 检查表是否存在 - result = self.connect.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name] - ).fetchall() - - if not result: - # 如果表不存在,则创建新表 - self.connect.execute( - "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)" - ) - self.connect.execute("CREATE SEQUENCE seq_id START 1;") - - def __get_messages_by_conv_uid(self, conv_uid: str): - cursor = self.connect.cursor() - cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) - content = cursor.fetchone() - if content: - return content[0] - else: - return None - - def messages(self) -> List[OnceConversation]: - context = self.__get_messages_by_conv_uid(self.chat_seesion_id) - if context: - conversations: List[OnceConversation] = json.loads(context) - return conversations - return [] - - def create(self, chat_mode, summary: str, user_name: str) -> None: - try: - cursor = self.connect.cursor() - cursor.execute( - "INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)", - [self.chat_seesion_id, chat_mode, summary, user_name, "", ""], - ) - cursor.commit() - self.connect.commit() - except Exception as e: - print("init create conversation log error!" + str(e)) - - def append(self, once_message: OnceConversation) -> None: - context = self.__get_messages_by_conv_uid(self.chat_seesion_id) - conversations: List[OnceConversation] = [] - if context: - conversations = json.loads(context) - conversations.append(_conversation_to_dict(once_message)) - cursor = self.connect.cursor() - if context: - cursor.execute( - "UPDATE chat_history set messages=? where conv_uid=?", - [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id], - ) - else: - cursor.execute( - "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)", - [ - self.chat_seesion_id, - once_message.chat_mode, - once_message.get_latest_user_message().content, - once_message.user_name, - once_message.sys_code, - json.dumps(conversations, ensure_ascii=False), - ], - ) - cursor.commit() - self.connect.commit() - - def update(self, messages: List[OnceConversation]) -> None: - cursor = self.connect.cursor() - cursor.execute( - "UPDATE chat_history set messages=? where conv_uid=?", - [json.dumps(messages, ensure_ascii=False), self.chat_seesion_id], - ) - cursor.commit() - self.connect.commit() - - def clear(self) -> None: - cursor = self.connect.cursor() - cursor.execute( - "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] - ) - cursor.commit() - self.connect.commit() - - def delete(self) -> bool: - cursor = self.connect.cursor() - cursor.execute( - "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] - ) - cursor.commit() - return True - - def conv_info(self, conv_uid: str = None) -> None: - cursor = self.connect.cursor() - cursor.execute( - "SELECT * FROM chat_history where conv_uid=? ", - [conv_uid], - ) - # 获取查询结果字段名 - fields = [field[0] for field in cursor.description] - - for row in cursor.fetchone(): - row_dict = {} - for i, field in enumerate(fields): - row_dict[field] = row[i] - return row_dict - - return {} - - def get_messages(self) -> List[OnceConversation]: - cursor = self.connect.cursor() - cursor.execute( - "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id] - ) - context = cursor.fetchone() - if context: - if context[0]: - return json.loads(context[0]) - return None - - @staticmethod - def conv_list( - user_name: Optional[str] = None, sys_code: Optional[str] = None - ) -> List[Dict]: - if os.path.isfile(duckdb_path): - cursor = duckdb.connect(duckdb_path).cursor() - query = "SELECT * FROM chat_history" - params = [] - conditions = [] - if user_name: - conditions.append("user_name = ?") - params.append(user_name) - if sys_code: - conditions.append("sys_code = ?") - params.append(sys_code) - - if conditions: - query += " WHERE " + " AND ".join(conditions) - - query += " ORDER BY id DESC LIMIT 20" - cursor.execute(query, params) - fields = [field[0] for field in cursor.description] - data = [] - for row in cursor.fetchall(): - row_dict = {} - for i, field in enumerate(fields): - row_dict[field] = row[i] - data.append(row_dict) - - return data - - return [] diff --git a/dbgpt/storage/chat_history/store_type/file_history.py b/dbgpt/storage/chat_history/store_type/file_history.py deleted file mode 100644 index efa9f4e69..000000000 --- a/dbgpt/storage/chat_history/store_type/file_history.py +++ /dev/null @@ -1,50 +0,0 @@ -import datetime -import json -import os -from pathlib import Path -from typing import List - -from dbgpt._private.config import Config -from dbgpt.core.interface.message import ( - OnceConversation, - _conversation_from_dict, - _conversations_to_dict, -) -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType - -CFG = Config() - - -class FileHistoryMemory(BaseChatHistoryMemory): - store_type: str = MemoryStoreType.File.value - - def __init__(self, chat_session_id: str): - now = datetime.datetime.now() - date_string = now.strftime("%Y%m%d") - path: str = f"{CFG.message_dir}/{date_string}" - os.makedirs(path, exist_ok=True) - - dir_path = Path(path) - self.file_path = Path(dir_path / f"{chat_session_id}.json") - if not self.file_path.exists(): - self.file_path.touch() - self.file_path.write_text(json.dumps([])) - - def messages(self) -> List[OnceConversation]: - items = json.loads(self.file_path.read_text()) - history: List[OnceConversation] = [] - for onece in items: - messages = _conversation_from_dict(onece) - history.append(messages) - return history - - def append(self, once_message: OnceConversation) -> None: - historys = self.messages() - historys.append(once_message) - self.file_path.write_text( - json.dumps(_conversations_to_dict(historys), ensure_ascii=False, indent=4), - encoding="UTF-8", - ) - - def clear(self) -> None: - self.file_path.write_text(json.dumps([])) diff --git a/dbgpt/storage/chat_history/store_type/mem_history.py b/dbgpt/storage/chat_history/store_type/mem_history.py deleted file mode 100644 index 3c5264627..000000000 --- a/dbgpt/storage/chat_history/store_type/mem_history.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import List - -from dbgpt._private.config import Config -from dbgpt.core.interface.message import OnceConversation -from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType -from dbgpt.util.custom_data_structure import FixedSizeDict - -CFG = Config() - - -class MemHistoryMemory(BaseChatHistoryMemory): - store_type: str = MemoryStoreType.Memory.value - - histroies_map = FixedSizeDict(100) - - def __init__(self, chat_session_id: str): - self.chat_seesion_id = chat_session_id - self.histroies_map.update({chat_session_id: []}) - - def messages(self) -> List[OnceConversation]: - return self.histroies_map.get(self.chat_seesion_id) - - def append(self, once_message: OnceConversation) -> None: - self.histroies_map.get(self.chat_seesion_id).append(once_message) - - def clear(self) -> None: - self.histroies_map.pop(self.chat_seesion_id) diff --git a/dbgpt/storage/metadata/__init__.py b/dbgpt/storage/metadata/__init__.py index 8660866d9..9c0f6e88a 100644 --- a/dbgpt/storage/metadata/__init__.py +++ b/dbgpt/storage/metadata/__init__.py @@ -1,6 +1,7 @@ -from dbgpt.storage.metadata._base_dao import BaseDao -from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory -from dbgpt.storage.metadata.db_manager import ( +"""Module for handling metadata storage.""" +from dbgpt.storage.metadata._base_dao import BaseDao # noqa: F401 +from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory # noqa: F401 +from dbgpt.storage.metadata.db_manager import ( # noqa: F401 BaseModel, DatabaseManager, Model, diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 02b6f0291..1b84b34b3 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union from sqlalchemy.orm.session import Session @@ -44,6 +44,7 @@ class BaseDao(Generic[T, REQ, RES]): self, db_manager: Optional[DatabaseManager] = None, ) -> None: + """Create a BaseDao instance.""" self._db_manager = db_manager or db def get_raw_session(self) -> Session: @@ -52,9 +53,7 @@ class BaseDao(Generic[T, REQ, RES]): Your should commit or rollback the session manually. We suggest you use :meth:`session` instead. - Example: - .. code-block:: python user = User(name="Edward Snowden") @@ -63,13 +62,14 @@ class BaseDao(Generic[T, REQ, RES]): session.commit() session.close() """ - return self._db_manager._session() + return self._db_manager._session() # type: ignore @contextmanager - def session(self, commit: Optional[bool] = True) -> Session: + def session(self, commit: Optional[bool] = True) -> Iterator[Session]: """Provide a transactional scope around a series of operations. - If raise an exception, the session will be roll back automatically, otherwise it will be committed. + If raise an exception, the session will be roll back automatically, otherwise + it will be committed. Example: .. code-block:: python @@ -78,7 +78,8 @@ class BaseDao(Generic[T, REQ, RES]): session.query(User).filter(User.name == "Edward Snowden").first() Args: - commit (Optional[bool], optional): Whether to commit the session. Defaults to True. + commit (Optional[bool], optional): Whether to commit the session. Defaults + to True. Returns: Session: A session object. @@ -147,7 +148,8 @@ class BaseDao(Generic[T, REQ, RES]): session.add(entry) req = self.to_request(entry) session.commit() - return self.get_one(req) + res = self.get_one(req) + return res # type: ignore def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES: """Update an entity object. @@ -163,11 +165,14 @@ class BaseDao(Generic[T, REQ, RES]): entry = query.first() if entry is None: raise Exception("Invalid request") - for key, value in update_request.dict().items(): + for key, value in update_request.dict().items(): # type: ignore if value is not None: setattr(entry, key, value) session.merge(entry) - return self.get_one(self.to_request(entry)) + res = self.get_one(self.to_request(entry)) + if not res: + raise Exception("Update failed") + return res def delete(self, query_request: QUERY_SPEC) -> None: """Delete an entity object. @@ -179,7 +184,8 @@ class BaseDao(Generic[T, REQ, RES]): result_list = self._get_entity_list(session, query_request) if len(result_list) != 1: raise ValueError( - f"Delete request should return one result, but got {len(result_list)}" + f"Delete request should return one result, but got " + f"{len(result_list)}" ) session.delete(result_list[0]) @@ -241,11 +247,11 @@ class BaseDao(Generic[T, REQ, RES]): query = self._create_query_object(session, query_request) total_count = query.count() items = query.offset((page - 1) * page_size).limit(page_size) - items = [self.to_response(item) for item in items] + res_items = [self.to_response(item) for item in items] total_pages = (total_count + page_size - 1) // page_size return PaginationResult( - items=items, + items=res_items, total_count=total_count, total_pages=total_pages, page=page, @@ -273,4 +279,4 @@ class BaseDao(Generic[T, REQ, RES]): if isinstance(value, (list, tuple, dict, set)): continue query = query.filter(getattr(model_cls, key) == value) - return query + return query # type: ignore diff --git a/dbgpt/storage/metadata/db_factory.py b/dbgpt/storage/metadata/db_factory.py index c288a149d..c128a540e 100644 --- a/dbgpt/storage/metadata/db_factory.py +++ b/dbgpt/storage/metadata/db_factory.py @@ -1,19 +1,26 @@ +"""UnifiedDBManagerFactory is a factory class to create a DatabaseManager instance.""" from dbgpt.component import BaseComponent, ComponentType, SystemApp from .db_manager import DatabaseManager class UnifiedDBManagerFactory(BaseComponent): + """UnfiedDBManagerFactory class.""" + name = ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY + """The name of the factory.""" def __init__(self, system_app: SystemApp, db_manager: DatabaseManager): + """Create a UnifiedDBManagerFactory instance.""" super().__init__(system_app) self._db_manager = db_manager def init_app(self, system_app: SystemApp): + """Initialize the factory with the system app.""" pass def create(self) -> DatabaseManager: + """Create a DatabaseManager instance.""" if not self._db_manager: raise RuntimeError("db_manager is not initialized") if not self._db_manager.is_initialized: diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py index 7f52939c1..663ee779d 100644 --- a/dbgpt/storage/metadata/db_manager.py +++ b/dbgpt/storage/metadata/db_manager.py @@ -1,8 +1,9 @@ +"""The database manager.""" from __future__ import annotations import logging from contextlib import contextmanager -from typing import ClassVar, Dict, Generic, Optional, Type, TypeVar, Union +from typing import ClassVar, Dict, Generic, Iterator, Optional, Type, TypeVar, Union from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm from sqlalchemy.orm import ( @@ -21,19 +22,20 @@ logger = logging.getLogger(__name__) T = TypeVar("T", bound="BaseModel") -class _QueryObject: - """The query object.""" - - def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]): - return model_cls.query_class( - model_cls, session=model_cls.__db_manager__._session() - ) +# class _QueryObject: +# """The query object.""" +# +# def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]): +# return model_cls.query_class( +# model_cls, session=model_cls.__db_manager__._session() +# ) +# class BaseQuery(orm.Query): - def paginate_query( - self, page: Optional[int] = 1, per_page: Optional[int] = 20 - ) -> PaginationResult: + """Base query class.""" + + def paginate_query(self, page: int = 1, per_page: int = 20) -> PaginationResult: """Paginate the query. Example: @@ -56,10 +58,10 @@ class BaseQuery(orm.Query): ) print(pagination) - Args: page (Optional[int], optional): The page number. Defaults to 1. - per_page (Optional[int], optional): The number of items per page. Defaults to 20. + per_page (Optional[int], optional): The number of items per page. Defaults + to 20. Returns: PaginationResult: The pagination result. """ @@ -100,7 +102,6 @@ class DatabaseManager: """The database manager. Examples: - .. code-block:: python from urllib.parse import quote_plus as urlquote, quote @@ -161,6 +162,7 @@ class DatabaseManager: Query = BaseQuery def __init__(self): + """Create a DatabaseManager.""" self._db_url = None self._base: DeclarativeMeta = self._make_declarative_base(_Model) self._engine: Optional[Engine] = None @@ -169,12 +171,12 @@ class DatabaseManager: @property def Model(self) -> _Model: """Get the declarative base.""" - return self._base + return self._base # type: ignore @property def metadata(self) -> MetaData: """Get the metadata.""" - return self.Model.metadata + return self.Model.metadata # type: ignore @property def engine(self): @@ -183,11 +185,11 @@ class DatabaseManager: @property def is_initialized(self) -> bool: - """Whether the database manager is initialized.""" "" + """Whether the database manager is initialized.""" return self._engine is not None and self._session is not None @contextmanager - def session(self, commit: Optional[bool] = True) -> Session: + def session(self, commit: Optional[bool] = True) -> Iterator[Session]: """Get the session with context manager. This context manager handles the lifecycle of a SQLAlchemy session. @@ -199,7 +201,6 @@ class DatabaseManager: read and write operations. Examples: - .. code-block:: python # For write operations (insert, update, delete): @@ -211,7 +212,8 @@ class DatabaseManager: # For read-only operations: with db.session(commit=False) as session: user = session.query(User).filter_by(name="John Doe").first() - # session.commit() is NOT called, as it's unnecessary for read operations + # session.commit() is NOT called, as it's unnecessary for read + # operations Args: commit (Optional[bool], optional): Whether to commit the session. @@ -237,16 +239,15 @@ class DatabaseManager: with the ORM object are complete. c. Re-bind the instance to a new session if further interaction is required after the session is closed. - """ if not self.is_initialized: raise RuntimeError("The database manager is not initialized.") - session = self._session() + session = self._session() # type: ignore try: yield session if commit: session.commit() - except: + except Exception: session.rollback() raise finally: @@ -266,10 +267,10 @@ class DatabaseManager: if not isinstance(model, DeclarativeMeta): model = declarative_base(cls=model, name="Model") if not getattr(model, "query_class", None): - model.query_class = self.Query + model.query_class = self.Query # type: ignore # model.query = _QueryObject() - model.__db_manager__ = self - return model + model.__db_manager__ = self # type: ignore + return model # type: ignore def init_db( self, @@ -284,11 +285,14 @@ class DatabaseManager: Args: db_url (Union[str, URL]): The database url. - engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. + engine_args (Optional[Dict], optional): The engine arguments. Defaults to + None. base (Optional[DeclarativeMeta]): The base class. Defaults to None. query_class (BaseQuery, optional): The query class. Defaults to BaseQuery. - override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False. - session_options (Optional[Dict], optional): The session options. Defaults to None. + override_query_class (Optional[bool], optional): Whether to override the + query class. Defaults to False. + session_options (Optional[Dict], optional): The session options. Defaults + to None. """ if session_options is None: session_options = {} @@ -309,8 +313,8 @@ class DatabaseManager: session_options.setdefault("query_cls", self.Query) session_factory = sessionmaker(bind=self._engine, **session_options) # self._session = scoped_session(session_factory) - self._session = session_factory - self._base.metadata.bind = self._engine + self._session = session_factory # type: ignore + self._base.metadata.bind = self._engine # type: ignore def init_default_db( self, @@ -333,24 +337,28 @@ class DatabaseManager: base (Optional[DeclarativeMeta]): The base class. Defaults to None. """ if not engine_args: - engine_args = {} - # Pool class - engine_args["poolclass"] = QueuePool - # The number of connections to keep open inside the connection pool. - engine_args["pool_size"] = 10 - # The maximum overflow size of the pool when the number of connections be used in the pool is exceeded( - # pool_size). - engine_args["max_overflow"] = 20 - # The number of seconds to wait before giving up on getting a connection from the pool. - engine_args["pool_timeout"] = 30 - # Recycle the connection if it has been idle for this many seconds. - engine_args["pool_recycle"] = 3600 - # Enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout. - engine_args["pool_pre_ping"] = True + engine_args = { + # Pool class + "poolclass": QueuePool, + # The number of connections to keep open inside the connection pool. + "pool_size": 10, + # The maximum overflow size of the pool when the number of connections + # be used in the pool is exceeded(pool_size). + "max_overflow": 20, + # The number of seconds to wait before giving up on getting a connection + # from the pool. + "pool_timeout": 30, + # Recycle the connection if it has been idle for this many seconds. + "pool_recycle": 3600, + # Enable the connection pool “pre-ping” feature that tests connections + # for liveness upon each checkout. + "pool_pre_ping": True, + } self.init_db(f"sqlite:///{sqlite_path}", engine_args, base) def create_all(self): + """Create all tables.""" self.Model.metadata.create_all(self._engine) @staticmethod @@ -364,9 +372,7 @@ class DatabaseManager: """Build the database manager from the db_url_or_db. Examples: - Build from the database url. - .. code-block:: python from dbgpt.storage.metadata import DatabaseManager @@ -389,16 +395,19 @@ class DatabaseManager: print(User.query.filter(User.name == "test").all()) Args: - db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager. - engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. + db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the + database manager. + engine_args (Optional[Dict], optional): The engine arguments. Defaults to + None. base (Optional[DeclarativeMeta]): The base class. Defaults to None. query_class (BaseQuery, optional): The query class. Defaults to BaseQuery. - override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False. + override_query_class (Optional[bool], optional): Whether to override the + query class. Defaults to False. Returns: DatabaseManager: The database manager. """ - if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL): + if isinstance(db_url_or_db, (str, URL)): db_manager = DatabaseManager() db_manager.init_db( db_url_or_db, engine_args, base, query_class, override_query_class @@ -408,7 +417,8 @@ class DatabaseManager: return db_url_or_db else: raise ValueError( - f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}" + f"db_url_or_db should be either url or a DatabaseManager, got " + f"{type(db_url_or_db)}" ) @@ -422,7 +432,6 @@ Examples: >>> with db.session() as session: ... session.query(...) ... - >>> from dbgpt.storage.metadata import db, Model >>> from urllib.parse import quote_plus as urlquote, quote >>> db_name = "dbgpt" @@ -430,7 +439,10 @@ Examples: >>> db_port = 3306 >>> user = "root" >>> password = "123456" - >>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}" + >>> url = ( + ... f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}" + ... f":{str(db_port)}/{db_name}" + ... ) >>> engine_args = { ... "pool_size": 10, ... "max_overflow": 20, @@ -460,18 +472,21 @@ class BaseCRUDMixin(Generic[T]): @classmethod def db(cls) -> DatabaseManager: """Get the database manager.""" - return cls.__db_manager__ + return cls.__db_manager__ # type: ignore class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]): """The base model class that includes CRUD convenience methods.""" __abstract__ = True + """Whether the model is abstract.""" def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]: - class CRUDMixin(BaseCRUDMixin[T], Generic[T]): - """Mixin that adds convenience methods for CRUD (create, read, update, delete)""" + """Create a model.""" + + class CRUDMixin(BaseCRUDMixin[T], Generic[T]): # type: ignore + """Mixin that adds convenience methods for CRUD.""" _db_manager: DatabaseManager = db_manager @@ -485,7 +500,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]: """Get the database manager.""" return cls._db_manager - class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): + class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): # type: ignore """Base model class that includes CRUD convenience methods.""" __abstract__ = True @@ -493,7 +508,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]: return _NewModel -Model = create_model(db) +Model: Type = create_model(db) def initialize_db( @@ -511,8 +526,10 @@ def initialize_db( db_name (str): The database name. engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. base (Optional[DeclarativeMeta]): The base class. Defaults to None. - try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False. - session_options (Optional[Dict], optional): The session options. Defaults to None. + try_to_create_db (Optional[bool], optional): Whether to try to create the + database. Defaults to False. + session_options (Optional[Dict], optional): The session options. Defaults to + None. Returns: DatabaseManager: The database manager. """ diff --git a/dbgpt/storage/metadata/db_storage.py b/dbgpt/storage/metadata/db_storage.py index 6b6e9716a..0db90c120 100644 --- a/dbgpt/storage/metadata/db_storage.py +++ b/dbgpt/storage/metadata/db_storage.py @@ -1,5 +1,6 @@ +"""Database storage implementation using SQLAlchemy.""" from contextlib import contextmanager -from typing import Dict, List, Optional, Type, Union +from typing import Dict, Iterator, List, Optional, Type, Union from sqlalchemy import URL from sqlalchemy.orm import DeclarativeMeta, Session @@ -17,8 +18,8 @@ from .db_manager import BaseModel, BaseQuery, DatabaseManager def _copy_public_properties(src: BaseModel, dest: BaseModel): - """Simple copy public properties from src to dest""" - for column in src.__table__.columns: + """Copy public properties from src to dest.""" + for column in src.__table__.columns: # type: ignore if column.name != "id": value = getattr(src, column.name) if value is not None: @@ -26,6 +27,8 @@ def _copy_public_properties(src: BaseModel, dest: BaseModel): class SQLAlchemyStorage(StorageInterface[T, BaseModel]): + """Database storage implementation using SQLAlchemy.""" + def __init__( self, db_url_or_db: Union[str, URL, DatabaseManager], @@ -36,6 +39,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]): base: Optional[DeclarativeMeta] = None, query_class=BaseQuery, ): + """Create a SQLAlchemyStorage instance.""" super().__init__(serializer=serializer, adapter=adapter) self.db_manager = DatabaseManager.build_from( db_url_or_db, engine_args, base, query_class @@ -43,16 +47,19 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]): self._model_class = model_class @contextmanager - def session(self) -> Session: + def session(self) -> Iterator[Session]: + """Return a session.""" with self.db_manager.session() as session: yield session def save(self, data: T) -> None: + """Save data to the storage.""" with self.session() as session: model_instance = self.adapter.to_storage_format(data) session.add(model_instance) def update(self, data: T) -> None: + """Update data in the storage.""" with self.session() as session: query = self.adapter.get_query_for_identifier( self._model_class, data.identifier, session=session @@ -66,6 +73,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]): return def save_or_update(self, data: T) -> None: + """Save or update data in the storage.""" with self.session() as session: query = self.adapter.get_query_for_identifier( self._model_class, data.identifier, session=session @@ -79,6 +87,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]): self.save(data) def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]: + """Load data by identifier from the storage.""" with self.session() as session: query = self.adapter.get_query_for_identifier( self._model_class, resource_id, session=session @@ -89,6 +98,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]): return None def delete(self, resource_id: ResourceIdentifier) -> None: + """Delete data by identifier from the storage.""" with self.session() as session: query = self.adapter.get_query_for_identifier( self._model_class, resource_id, session=session diff --git a/dbgpt/storage/metadata/meta_data.py b/dbgpt/storage/metadata/meta_data.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/dbgpt/storage/schema.py b/dbgpt/storage/schema.py index 9ae8a67e4..1119922c4 100644 --- a/dbgpt/storage/schema.py +++ b/dbgpt/storage/schema.py @@ -1,14 +1,21 @@ +"""Database information class and database type enumeration.""" import os from enum import Enum +from typing import Optional class DbInfo: + """Database information class.""" + def __init__(self, name, is_file_db: bool = False): + """Create a new instance of DbInfo.""" self.name = name self.is_file_db = is_file_db class DBType(Enum): + """Database type enumeration.""" + Mysql = DbInfo("mysql") OCeanBase = DbInfo("oceanbase") DuckDb = DbInfo("duckdb", True) @@ -22,14 +29,24 @@ class DBType(Enum): Doris = DbInfo("doris") Hive = DbInfo("hive") - def value(self): + def value(self) -> str: + """Return the name of the database type.""" return self._value_.name - def is_file_db(self): + def is_file_db(self) -> bool: + """Return whether the database is a file database.""" return self._value_.is_file_db @staticmethod - def of_db_type(db_type: str): + def of_db_type(db_type: str) -> Optional["DBType"]: + """Return the database type of the given name. + + Args: + db_type (str): The name of the database type. + + Returns: + Optional[DBType]: The database type of the given name. + """ for item in DBType: if item.value() == db_type: return item @@ -37,7 +54,7 @@ class DBType(Enum): @staticmethod def parse_file_db_name_from_path(db_type: str, local_db_path: str): - """Parse out the database name of the embedded database from the file path""" + """Parse out the database name of the embedded database from the file path.""" base_name = os.path.basename(local_db_path) db_name = os.path.splitext(base_name)[0] if "." in db_name: diff --git a/dbgpt/storage/vector_store/__init__.py b/dbgpt/storage/vector_store/__init__.py index 40f508102..924680b39 100644 --- a/dbgpt/storage/vector_store/__init__.py +++ b/dbgpt/storage/vector_store/__init__.py @@ -1,3 +1,4 @@ +"""Vector Store Module.""" from typing import Any diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py index 769a48b08..2de3c3223 100644 --- a/dbgpt/storage/vector_store/base.py +++ b/dbgpt/storage/vector_store/base.py @@ -1,12 +1,13 @@ +"""Vector store base class.""" import logging import math import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, List, Optional - -from pydantic import BaseModel, Field +from typing import List, Optional +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core import Embeddings from dbgpt.rag.chunk import Chunk logger = logging.getLogger(__name__) @@ -15,6 +16,11 @@ logger = logging.getLogger(__name__) class VectorStoreConfig(BaseModel): """Vector store config.""" + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + name: str = Field( default="dbgpt_collection", description="The name of vector store, if not set, will use the default name.", @@ -28,7 +34,7 @@ class VectorStoreConfig(BaseModel): description="The password of vector store, if not set, will use the default " "password.", ) - embedding_fn: Optional[Any] = Field( + embedding_fn: Optional[Embeddings] = Field( default=None, description="The embedding function of vector store, if not set, will use the " "default embedding function.", @@ -47,27 +53,31 @@ class VectorStoreConfig(BaseModel): class VectorStoreBase(ABC): - """base class for vector store database""" + """Vector store base class.""" @abstractmethod def load_document(self, chunks: List[Chunk]) -> List[str]: - """load document in vector database. + """Load document in vector database. + Args: - - chunks: document chunks. + chunks(List[Chunk]): document chunks. + Return: - - ids: chunks ids. + List[str]: chunk ids. """ - pass def load_document_with_limit( self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1 ) -> List[str]: - """load document in vector database with limit. + """Load document in vector database with specified limit. + Args: - chunks: document chunks. - max_chunks_once_load: Max number of chunks to load at once. - max_threads: Max number of threads to use. + chunks(List[Chunk]): Document chunks. + max_chunks_once_load(int): Max number of chunks to load at once. + max_threads(int): Max number of threads to use. + Return: + List[str]: Chunk ids. """ # Group the chunks into chunks of size max_chunks chunk_groups = [ @@ -96,13 +106,15 @@ class VectorStoreBase(ABC): return ids @abstractmethod - def similar_search(self, text, topk) -> List[Chunk]: - """similar search in vector database. + def similar_search(self, text: str, topk: int) -> List[Chunk]: + """Similar search in vector database. + Args: - - text: query text - - topk: topk + text(str): The query text. + topk(int): The number of similar documents to return. + Return: - - chunks: chunks. + List[Chunk]: The similar documents. """ pass @@ -110,38 +122,43 @@ class VectorStoreBase(ABC): def similar_search_with_scores( self, text, topk, score_threshold: float ) -> List[Chunk]: - """similar search in vector database with scores. + """Similar search with scores in vector database. + Args: - - text: query text - - topk: topk - - score_threshold: score_threshold: Optional, a floating point value between 0 to 1 + text(str): The query text. + topk(int): The number of similar documents to return. + score_threshold(int): score_threshold: Optional, a floating point value + between 0 to 1 Return: - - chunks: chunks. + List[Chunk]: The similar documents. """ - pass @abstractmethod def vector_name_exists(self) -> bool: - """is vector store name exist.""" + """Whether vector name exists.""" return False @abstractmethod - def delete_by_ids(self, ids): - """delete vector by ids. + def delete_by_ids(self, ids: str): + """Delete vectors by ids. + Args: - - ids: vector ids + ids(str): The ids of vectors to delete, separated by comma. """ @abstractmethod - def delete_vector_name(self, vector_name): - """delete vector name. + def delete_vector_name(self, vector_name: str): + """Delete vector by name. + Args: - - vector_name: vector store name + vector_name(str): The name of vector to delete. """ - pass def _normalization_vectors(self, vectors): - """normalization vectors to scale[0,1]""" + """Return L2-normalization vectors to scale[0,1]. + + Normalization vectors to scale[0,1]. + """ import numpy as np norm = np.linalg.norm(vectors) diff --git a/dbgpt/storage/vector_store/chroma_store.py b/dbgpt/storage/vector_store/chroma_store.py index cd5b76e71..93c4c239d 100644 --- a/dbgpt/storage/vector_store/chroma_store.py +++ b/dbgpt/storage/vector_store/chroma_store.py @@ -1,14 +1,18 @@ +"""Chroma vector store.""" import logging import os from typing import Any, List from chromadb import PersistentClient from chromadb.config import Settings -from pydantic import Field +from dbgpt._private.pydantic import Field from dbgpt.configs.model_config import PILOT_PATH + +# TODO: Recycle dependency on rag and storage from dbgpt.rag.chunk import Chunk -from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig + +from .base import VectorStoreBase, VectorStoreConfig logger = logging.getLogger(__name__) @@ -16,20 +20,28 @@ logger = logging.getLogger(__name__) class ChromaVectorConfig(VectorStoreConfig): """Chroma vector store config.""" + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + persist_path: str = Field( default=os.getenv("CHROMA_PERSIST_PATH", None), - description="The password of vector store, if not set, will use the default password.", + description="The password of vector store, if not set, will use the default " + "password.", ) collection_metadata: dict = Field( default=None, - description="the index metadata of vector store, if not set, will use the default metadata.", + description="the index metadata of vector store, if not set, will use the " + "default metadata.", ) class ChromaStore(VectorStoreBase): - """chroma database""" + """Chroma vector store.""" def __init__(self, vector_store_config: ChromaVectorConfig) -> None: + """Create a ChromaStore instance.""" from langchain.vectorstores import Chroma chroma_vector_config = vector_store_config.dict() @@ -59,6 +71,7 @@ class ChromaStore(VectorStoreBase): ) def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]: + """Search similar documents.""" logger.info("ChromaStore similar search") lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs) return [ @@ -67,14 +80,16 @@ class ChromaStore(VectorStoreBase): ] def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]: - """ + """Search similar documents with scores. + Chroma similar_search_with_score. Return docs and relevance scores in the range [0, 1]. Args: text(str): query text topk(int): return docs nums. Defaults to 4. - score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar. + score_threshold(float): score_threshold: Optional, a floating point value + between 0 to 1 to filter the resulting set of retrieved docs,0 is + dissimilar, 1 is most similar. """ logger.info("ChromaStore similar search with scores") docs_and_scores = ( @@ -87,8 +102,8 @@ class ChromaStore(VectorStoreBase): for doc, score in docs_and_scores ] - def vector_name_exists(self): - """is vector store name exist.""" + def vector_name_exists(self) -> bool: + """Whether vector name exists.""" logger.info(f"Check persist_dir: {self.persist_dir}") if not os.path.exists(self.persist_dir): return False @@ -98,6 +113,7 @@ class ChromaStore(VectorStoreBase): return len(files) > 0 def load_document(self, chunks: List[Chunk]) -> List[str]: + """Load document to vector store.""" logger.info("ChromaStore load document") texts = [chunk.content for chunk in chunks] metadatas = [chunk.metadata for chunk in chunks] @@ -105,14 +121,16 @@ class ChromaStore(VectorStoreBase): self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids) return ids - def delete_vector_name(self, vector_name): + def delete_vector_name(self, vector_name: str): + """Delete vector name.""" logger.info(f"chroma vector_name:{vector_name} begin delete...") self.vector_store_client.delete_collection() self._clean_persist_folder() return True def delete_by_ids(self, ids): - logger.info(f"begin delete chroma ids...") + """Delete vector by ids.""" + logger.info(f"begin delete chroma ids: {ids}") ids = ids.split(",") if len(ids) > 0: collection = self.vector_store_client._collection diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py index 0c2ffa064..851f07337 100644 --- a/dbgpt/storage/vector_store/connector.py +++ b/dbgpt/storage/vector_store/connector.py @@ -1,19 +1,26 @@ +"""Connector for vector store.""" + import os -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional, Type, cast from dbgpt.rag.chunk import Chunk from dbgpt.storage import vector_store from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig -connector = {} +connector: Dict[str, Type] = {} class VectorStoreConnector: + """The connector for vector store. - """VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1. - 1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate) - 2.similar_search: similarity search from vector_store - 3.similar_search_with_scores: similarity search with similarity score from vector_store + VectorStoreConnector, can connect different vector db provided load document api_v1 + and similar search api_v1. + + 1.load_document:knowledge document source into vector store.(Chroma, Milvus, + Weaviate). + 2.similar_search: similarity search from vector_store. + 3.similar_search_with_scores: similarity search with similarity score from + vector_store code example: >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -23,9 +30,12 @@ class VectorStoreConnector: """ def __init__( - self, vector_store_type: str, vector_store_config: VectorStoreConfig = None + self, + vector_store_type: str, + vector_store_config: Optional[VectorStoreConfig] = None, ) -> None: - """initialize vector store connector. + """Create a VectorStoreConnector instance. + Args: - vector_store_type: vector store type Milvus, Chroma, Weaviate - ctx: vector store config params. @@ -34,7 +44,7 @@ class VectorStoreConnector: self._register() if self._match(vector_store_type): - self.connector_class = connector.get(vector_store_type) + self.connector_class = connector[vector_store_type] else: raise Exception(f"Vector Store Type Not support. {0}", vector_store_type) @@ -44,11 +54,11 @@ class VectorStoreConnector: @classmethod def from_default( cls, - vector_store_type: str = None, + vector_store_type: Optional[str] = None, embedding_fn: Optional[Any] = None, vector_store_config: Optional[VectorStoreConfig] = None, ) -> "VectorStoreConnector": - """initialize default vector store connector.""" + """Initialize default vector store connector.""" vector_store_type = vector_store_type or os.getenv( "VECTOR_STORE_TYPE", "Chroma" ) @@ -56,22 +66,33 @@ class VectorStoreConnector: vector_store_config = vector_store_config or ChromaVectorConfig() vector_store_config.embedding_fn = embedding_fn - return cls(vector_store_type, vector_store_config) + real_vector_store_type = cast(str, vector_store_type) + return cls(real_vector_store_type, vector_store_config) def load_document(self, chunks: List[Chunk]) -> List[str]: - """load document in vector database. + """Load document in vector database. + Args: - chunks: document chunks. Return chunk ids. """ + max_chunks_once_load = ( + self._vector_store_config.max_chunks_once_load + if self._vector_store_config + else 10 + ) + max_threads = ( + self._vector_store_config.max_threads if self._vector_store_config else 1 + ) return self.client.load_document_with_limit( chunks, - self._vector_store_config.max_chunks_once_load, - self._vector_store_config.max_threads, + max_chunks_once_load, + max_threads, ) def similar_search(self, doc: str, topk: int) -> List[Chunk]: - """similar search in vector database. + """Similar search in vector database. + Args: - doc: query text - topk: topk @@ -83,14 +104,17 @@ class VectorStoreConnector: def similar_search_with_scores( self, doc: str, topk: int, score_threshold: float ) -> List[Chunk]: - """ + """Similar search with scores in vector database. + similar_search_with_score in vector database.. Return docs and relevance scores in the range [0, 1]. + Args: - - doc(str): query text - - topk(int): return docs nums. Defaults to 4. - - score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar. + doc(str): query text + topk(int): return docs nums. Defaults to 4. + score_threshold(float): score_threshold: Optional, a floating point value + between 0 to 1 to filter the resulting set of retrieved docs,0 is + dissimilar, 1 is most similar. Return: - chunks: chunks. """ @@ -98,32 +122,33 @@ class VectorStoreConnector: @property def vector_store_config(self) -> VectorStoreConfig: - """vector store config.""" + """Return the vector store config.""" + if not self._vector_store_config: + raise ValueError("vector store config not set.") return self._vector_store_config def vector_name_exists(self): - """is vector store name exist.""" + """Whether vector name exists.""" return self.client.vector_name_exists() - def delete_vector_name(self, vector_name): - """vector store delete + def delete_vector_name(self, vector_name: str): + """Delete vector name. + Args: - vector_name: vector store name """ return self.client.delete_vector_name(vector_name) def delete_by_ids(self, ids): - """vector store delete by ids. + """Delete vector by ids. + Args: - ids: vector ids """ return self.client.delete_by_ids(ids=ids) def _match(self, vector_store_type) -> bool: - if connector.get(vector_store_type): - return True - else: - return False + return bool(connector.get(vector_store_type)) def _register(self): for cls in vector_store.__all__: diff --git a/dbgpt/storage/vector_store/milvus_store.py b/dbgpt/storage/vector_store/milvus_store.py index 2748b5a1f..58281d2c2 100644 --- a/dbgpt/storage/vector_store/milvus_store.py +++ b/dbgpt/storage/vector_store/milvus_store.py @@ -1,13 +1,14 @@ +"""Milvus vector store.""" from __future__ import annotations import json import logging import os -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Iterable, List, Optional -from pydantic import Field - -from dbgpt.rag.chunk import Chunk, Document +from dbgpt._private.pydantic import Field +from dbgpt.core import Embeddings +from dbgpt.rag.chunk import Chunk from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig from dbgpt.util import string_utils @@ -17,6 +18,11 @@ logger = logging.getLogger(__name__) class MilvusVectorConfig(VectorStoreConfig): """Milvus vector store config.""" + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + uri: str = Field( default="localhost", description="The uri of milvus store, if not set, will use the default uri.", @@ -28,7 +34,8 @@ class MilvusVectorConfig(VectorStoreConfig): alias: str = Field( default="default", - description="The alias of milvus store, if not set, will use the default alias.", + description="The alias of milvus store, if not set, will use the default " + "alias.", ) user: str = Field( default=None, @@ -36,35 +43,42 @@ class MilvusVectorConfig(VectorStoreConfig): ) password: str = Field( default=None, - description="The password of milvus store, if not set, will use the default password.", + description="The password of milvus store, if not set, will use the default " + "password.", ) primary_field: str = Field( default="pk_id", - description="The primary field of milvus store, if not set, will use the default primary field.", + description="The primary field of milvus store, if not set, will use the " + "default primary field.", ) text_field: str = Field( default="content", - description="The text field of milvus store, if not set, will use the default text field.", + description="The text field of milvus store, if not set, will use the default " + "text field.", ) embedding_field: str = Field( default="vector", - description="The embedding field of milvus store, if not set, will use the default embedding field.", + description="The embedding field of milvus store, if not set, will use the " + "default embedding field.", ) metadata_field: str = Field( default="metadata", - description="The metadata field of milvus store, if not set, will use the default metadata field.", + description="The metadata field of milvus store, if not set, will use the " + "default metadata field.", ) secure: str = Field( default="", - description="The secure of milvus store, if not set, will use the default secure.", + description="The secure of milvus store, if not set, will use the default " + "secure.", ) class MilvusStore(VectorStoreBase): - """Milvus database""" + """Milvus vector store.""" def __init__(self, vector_store_config: MilvusVectorConfig) -> None: - """MilvusStore init. + """Create a MilvusStore instance. + Args: vector_store_config (MilvusVectorConfig): MilvusStore config. refer to https://milvus.io/docs/v2.0.x/manage_connection.md @@ -93,8 +107,11 @@ class MilvusStore(VectorStoreBase): hex_str = bytes_str.hex() self.collection_name = hex_str - self.embedding = vector_store_config.embedding_fn - self.fields = [] + if not vector_store_config.embedding_fn: + raise ValueError("embedding is required for MilvusStore") + + self.embedding: Embeddings = vector_store_config.embedding_fn + self.fields: List = [] self.alias = milvus_vector_config.get("alias") or "default" # use HNSW by default. @@ -124,7 +141,8 @@ class MilvusStore(VectorStoreBase): if (self.username is None) != (self.password is None): raise ValueError( - "Both username and password must be set to use authentication for Milvus" + "Both username and password must be set to use authentication for " + "Milvus" ) if self.username: connect_kwargs["user"] = self.username @@ -139,7 +157,10 @@ class MilvusStore(VectorStoreBase): ) def init_schema_and_load(self, vector_name, documents) -> List[str]: - """Create a Milvus collection, indexes it with HNSW, load document. + """Create a Milvus collection. + + Create a Milvus collection, indexes it with HNSW, load document. + Args: vector_name (Embeddings): your collection name. documents (List[str]): Text to insert. @@ -155,7 +176,7 @@ class MilvusStore(VectorStoreBase): connections, utility, ) - from pymilvus.orm.types import infer_dtype_bydata + from pymilvus.orm.types import infer_dtype_bydata # noqa: F401 except ImportError: raise ValueError( "Could not import pymilvus python package. " @@ -240,10 +261,10 @@ class MilvusStore(VectorStoreBase): partition_name: Optional[str] = None, timeout: Optional[int] = None, ) -> List[str]: - """add text data into Milvus.""" + """Add text data into Milvus.""" insert_dict: Any = {self.text_field: list(texts)} try: - import numpy as np + import numpy as np # noqa: F401 text_vector = self.embedding.embed_documents(list(texts)) insert_dict[self.vector_field] = text_vector @@ -268,7 +289,7 @@ class MilvusStore(VectorStoreBase): return res.primary_keys def load_document(self, chunks: List[Chunk]) -> List[str]: - """load document in vector database.""" + """Load document in vector database.""" batch_size = 500 batched_list = [ chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size) @@ -280,6 +301,7 @@ class MilvusStore(VectorStoreBase): return doc_ids def similar_search(self, text, topk) -> List[Chunk]: + """Perform a search on a query string and return results.""" from pymilvus import Collection, DataType """similar_search in vector database.""" @@ -409,12 +431,14 @@ class MilvusStore(VectorStoreBase): return ret[0], ret def vector_name_exists(self): + """Whether vector name exists.""" from pymilvus import utility """is vector store name exist.""" return utility.has_collection(self.collection_name) - def delete_vector_name(self, vector_name): + def delete_vector_name(self, vector_name: str): + """Delete vector name.""" from pymilvus import utility """milvus delete collection name""" @@ -423,11 +447,12 @@ class MilvusStore(VectorStoreBase): return True def delete_by_ids(self, ids): + """Delete vector by ids.""" from pymilvus import Collection self.col = Collection(self.collection_name) - """milvus delete vectors by ids""" - logger.info(f"begin delete milvus ids...") + # milvus delete vectors by ids + logger.info(f"begin delete milvus ids: {ids}") delete_ids = ids.split(",") doc_ids = [int(doc_id) for doc_id in delete_ids] delet_expr = f"{self.primary_field} in {doc_ids}" diff --git a/dbgpt/storage/vector_store/pgvector_store.py b/dbgpt/storage/vector_store/pgvector_store.py index 2563a2893..7038e4c7e 100644 --- a/dbgpt/storage/vector_store/pgvector_store.py +++ b/dbgpt/storage/vector_store/pgvector_store.py @@ -1,9 +1,9 @@ +"""Postgres vector store.""" import logging from typing import Any, List -from pydantic import Field - from dbgpt._private.config import Config +from dbgpt._private.pydantic import Field from dbgpt.rag.chunk import Chunk from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig @@ -15,21 +15,26 @@ CFG = Config() class PGVectorConfig(VectorStoreConfig): """PG vector store config.""" + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + connection_string: str = Field( default=None, - description="the connection string of vector store, if not set, will use the default connection string.", + description="the connection string of vector store, if not set, will use the " + "default connection string.", ) class PGVectorStore(VectorStoreBase): - """`Postgres.PGVector` vector store. + """PG vector store. To use this, you should have the ``pgvector`` python package installed. """ def __init__(self, vector_store_config: PGVectorConfig) -> None: - """init pgvector storage""" - + """Create a PGVectorStore instance.""" from langchain.vectorstores import PGVector self.connection_string = vector_store_config.connection_string @@ -42,23 +47,43 @@ class PGVectorStore(VectorStoreBase): connection_string=self.connection_string, ) - def similar_search(self, text, topk, **kwargs: Any) -> None: + def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]: + """Perform similar search in PGVector.""" return self.vector_store_client.similarity_search(text, topk) - def vector_name_exists(self): + def vector_name_exists(self) -> bool: + """Check if vector name exists.""" try: self.vector_store_client.create_collection() return True except Exception as e: - logger.error("vector_name_exists error", e.message) + logger.error(f"vector_name_exists error, {str(e)}") return False def load_document(self, chunks: List[Chunk]) -> List[str]: + """Load document to PGVector. + + Args: + chunks(List[Chunk]): document chunks. + + Return: + List[str]: chunk ids. + """ lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks] return self.vector_store_client.from_documents(lc_documents) - def delete_vector_name(self, vector_name): + def delete_vector_name(self, vector_name: str): + """Delete vector by name. + + Args: + vector_name(str): vector name. + """ return self.vector_store_client.delete_collection() - def delete_by_ids(self, ids): + def delete_by_ids(self, ids: str): + """Delete vector by ids. + + Args: + ids(str): vector ids, separated by comma. + """ return self.vector_store_client.delete(ids) diff --git a/dbgpt/storage/vector_store/weaviate_store.py b/dbgpt/storage/vector_store/weaviate_store.py index 77744489e..776d40e08 100644 --- a/dbgpt/storage/vector_store/weaviate_store.py +++ b/dbgpt/storage/vector_store/weaviate_store.py @@ -1,14 +1,13 @@ +"""Weaviate vector store.""" import logging import os from typing import List -from langchain.schema import Document -from pydantic import Field - from dbgpt._private.config import Config -from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH +from dbgpt._private.pydantic import Field from dbgpt.rag.chunk import Chunk -from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig + +from .base import VectorStoreBase, VectorStoreConfig logger = logging.getLogger(__name__) CFG = Config() @@ -17,6 +16,11 @@ CFG = Config() class WeaviateVectorConfig(VectorStoreConfig): """Weaviate vector store config.""" + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + weaviate_url: str = Field( default=os.getenv("WEAVIATE_URL", None), description="weaviate url address, if not set, will use the default url.", @@ -28,7 +32,7 @@ class WeaviateVectorConfig(VectorStoreConfig): class WeaviateStore(VectorStoreBase): - """Weaviate database""" + """Weaviate database.""" def __init__(self, vector_store_config: WeaviateVectorConfig) -> None: """Initialize with Weaviate client.""" @@ -49,8 +53,8 @@ class WeaviateStore(VectorStoreBase): self.vector_store_client = weaviate.Client(self.weaviate_url) - def similar_search(self, text: str, topk: int) -> None: - """Perform similar search in Weaviate""" + def similar_search(self, text: str, topk: int) -> List[Chunk]: + """Perform similar search in Weaviate.""" logger.info("Weaviate similar search") # nearText = { # "concepts": [text], @@ -68,15 +72,16 @@ class WeaviateStore(VectorStoreBase): docs = [] for r in res: docs.append( - Document( - page_content=r["page_content"], + Chunk( + content=r["page_content"], metadata={"metadata": r["metadata"]}, ) ) return docs def vector_name_exists(self) -> bool: - """Check if a vector name exists for a given class in Weaviate. + """Whether the vector name exists in Weaviate. + Returns: bool: True if the vector name exists, False otherwise. """ @@ -85,14 +90,15 @@ class WeaviateStore(VectorStoreBase): return True return False except Exception as e: - logger.error("vector_name_exists error", e.message) + logger.error(f"vector_name_exists error, {str(e)}") return False def _default_schema(self) -> None: - """ - Create the schema for Weaviate with a Document class containing metadata and text properties. - """ + """Create default schema in Weaviate. + Create the schema for Weaviate with a Document class containing metadata and + text properties. + """ schema = { "classes": [ { @@ -137,7 +143,7 @@ class WeaviateStore(VectorStoreBase): self.vector_store_client.schema.create(schema) def load_document(self, chunks: List[Chunk]) -> List[str]: - """Load documents into Weaviate""" + """Load document to Weaviate.""" logger.info("Weaviate load document") texts = [doc.content for doc in chunks] metadatas = [doc.metadata for doc in chunks] @@ -157,3 +163,5 @@ class WeaviateStore(VectorStoreBase): data_object=properties, class_name=self.vector_name ) self.vector_store_client.batch.flush() + # TODO: return ids + return [] diff --git a/requirements/lint-requirements.txt b/requirements/lint-requirements.txt index c8d53cc5d..4e9e454e8 100644 --- a/requirements/lint-requirements.txt +++ b/requirements/lint-requirements.txt @@ -11,4 +11,5 @@ isort==5.10.1 pyupgrade==3.1.0 types-requests types-beautifulsoup4 -types-Markdown \ No newline at end of file +types-Markdown +types-tqdm \ No newline at end of file