mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 10:05:13 +00:00
chore: Add pylint for storage (#1298)
This commit is contained in:
19
.mypy.ini
19
.mypy.ini
@@ -8,8 +8,8 @@ follow_imports = skip
|
||||
[mypy-dbgpt.datasource.*]
|
||||
follow_imports = skip
|
||||
|
||||
[mypy-dbgpt.storage.*]
|
||||
follow_imports = skip
|
||||
# [mypy-dbgpt.storage.*]
|
||||
# follow_imports = skip
|
||||
|
||||
[mypy-dbgpt.serve.*]
|
||||
follow_imports = skip
|
||||
@@ -57,4 +57,17 @@ ignore_missing_imports = True
|
||||
|
||||
[mypy-spacy.*]
|
||||
ignore_missing_imports = True
|
||||
follow_imports = skip
|
||||
follow_imports = skip
|
||||
|
||||
# Storage
|
||||
[mypy-msgpack.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-rocksdict.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-weaviate.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-pymilvus.*]
|
||||
ignore_missing_imports = True
|
||||
|
8
Makefile
8
Makefile
@@ -50,6 +50,7 @@ fmt: setup ## Format Python code
|
||||
# https://flake8.pycqa.org/en/latest/
|
||||
$(VENV_BIN)/flake8 dbgpt/core/
|
||||
$(VENV_BIN)/flake8 dbgpt/rag/
|
||||
$(VENV_BIN)/flake8 dbgpt/storage/
|
||||
# TODO: More package checks with flake8.
|
||||
|
||||
.PHONY: fmt-check
|
||||
@@ -60,8 +61,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
|
||||
$(VENV_BIN)/blackdoc --check dbgpt examples
|
||||
$(VENV_BIN)/flake8 dbgpt/core/
|
||||
$(VENV_BIN)/flake8 dbgpt/rag/
|
||||
# $(VENV_BIN)/blackdoc --check dbgpt examples
|
||||
# $(VENV_BIN)/flake8 dbgpt/core/
|
||||
$(VENV_BIN)/flake8 dbgpt/storage/
|
||||
|
||||
.PHONY: pre-commit
|
||||
pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing
|
||||
@@ -77,8 +77,10 @@ test-doc: $(VENV)/.testenv ## Run doctests
|
||||
.PHONY: mypy
|
||||
mypy: $(VENV)/.testenv ## Run mypy checks
|
||||
# https://github.com/python/mypy
|
||||
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
|
||||
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/
|
||||
# rag depends on core and storage, so we not need to check it again.
|
||||
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/
|
||||
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
|
||||
# TODO: More package checks with mypy.
|
||||
|
||||
.PHONY: coverage
|
||||
|
@@ -83,7 +83,7 @@ CREATE TABLE IF NOT EXISTS `chat_history`
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
|
||||
`chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode',
|
||||
`summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
|
||||
`summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
|
||||
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
|
||||
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
|
||||
`message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',
|
||||
|
@@ -0,0 +1,4 @@
|
||||
"""Old chat history module.
|
||||
|
||||
Just used by editor.
|
||||
"""
|
@@ -8,10 +8,10 @@ from dbgpt.core.interface.message import OnceConversation
|
||||
|
||||
|
||||
class MemoryStoreType(Enum):
|
||||
File = "file"
|
||||
Memory = "memory"
|
||||
# File = "file"
|
||||
# Memory = "memory"
|
||||
DB = "db"
|
||||
DuckDb = "duckdb"
|
||||
# DuckDb = "duckdb"
|
||||
|
||||
|
||||
class BaseChatHistoryMemory(ABC):
|
||||
@@ -24,18 +24,14 @@ class BaseChatHistoryMemory(ABC):
|
||||
def messages(self) -> List[OnceConversation]: # type: ignore
|
||||
"""Retrieve the messages from the local file"""
|
||||
|
||||
@abstractmethod
|
||||
def create(self, user_name: str) -> None:
|
||||
"""Append the message to the record in the local file"""
|
||||
# @abstractmethod
|
||||
# def create(self, user_name: str) -> None:
|
||||
# """Append the message to the record in the local file"""
|
||||
|
||||
@abstractmethod
|
||||
def append(self, message: OnceConversation) -> None:
|
||||
"""Append the message to the record in the local file"""
|
||||
|
||||
# @abstractmethod
|
||||
# def clear(self) -> None:
|
||||
# """Clear session memory from the local file"""
|
||||
|
||||
@abstractmethod
|
||||
def update(self, messages: List[OnceConversation]) -> None:
|
||||
pass
|
||||
@@ -45,14 +41,11 @@ class BaseChatHistoryMemory(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def conv_info(self, conv_uid: Optional[str] = None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_messages(self) -> List[OnceConversation]:
|
||||
def get_messages(self) -> List[Dict]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def conv_list(
|
||||
user_name: Optional[str] = None, sys_code: Optional[str] = None
|
||||
) -> List[Dict]:
|
@@ -1,13 +1,16 @@
|
||||
"""Module for chat history factory.
|
||||
|
||||
It will remove in the future, just support db store type now.
|
||||
"""
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
|
||||
|
||||
from .base import MemoryStoreType
|
||||
from .base import BaseChatHistoryMemory, MemoryStoreType
|
||||
|
||||
# Import first for auto create table
|
||||
from .store_type.meta_db_history import DbHistoryMemory
|
||||
from .meta_db_history import DbHistoryMemory
|
||||
|
||||
# TODO remove global variable
|
||||
CFG = Config()
|
||||
@@ -20,13 +23,6 @@ class ChatHistory:
|
||||
self.mem_store_class_map = {}
|
||||
|
||||
# Just support db store type after v0.4.6
|
||||
# from .store_type.duckdb_history import DuckdbHistoryMemory
|
||||
# from .store_type.file_history import FileHistoryMemory
|
||||
# from .store_type.mem_history import MemHistoryMemory
|
||||
# self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
|
||||
# self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
|
||||
# self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
|
||||
|
||||
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
|
||||
|
||||
def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
|
||||
@@ -53,24 +49,26 @@ class ChatHistory:
|
||||
Raises:
|
||||
ValueError: Invalid store type
|
||||
"""
|
||||
from .store_type.duckdb_history import DuckdbHistoryMemory
|
||||
from .store_type.file_history import FileHistoryMemory
|
||||
from .store_type.mem_history import MemHistoryMemory
|
||||
|
||||
if store_type == MemHistoryMemory.store_type:
|
||||
if store_type == "memory":
|
||||
logger.error(
|
||||
"Not support memory store type, just support db store type now"
|
||||
)
|
||||
raise ValueError(f"Invalid store type: {store_type}")
|
||||
|
||||
if store_type == FileHistoryMemory.store_type:
|
||||
if store_type == "file":
|
||||
logger.error("Not support file store type, just support db store type now")
|
||||
raise ValueError(f"Invalid store type: {store_type}")
|
||||
if store_type == DuckdbHistoryMemory.store_type:
|
||||
link1 = "https://docs.dbgpt.site/docs/faq/install#q6-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-sqlitel"
|
||||
link2 = "https://docs.dbgpt.site/docs/faq/install#q7-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-mysql"
|
||||
if store_type == "duckdb":
|
||||
link1 = (
|
||||
"https://docs.dbgpt.site/docs/latest/faq/install#q6-how-to-migrate-meta"
|
||||
"-table-chat_history-and-connect_config-from-duckdb-to-sqlite"
|
||||
)
|
||||
link2 = (
|
||||
"https://docs.dbgpt.site/docs/latest/faq/install/#q7-how-to-migrate-"
|
||||
"meta-table-chat_history-and-connect_config-from-duckdb-to-mysql"
|
||||
)
|
||||
logger.error(
|
||||
"Not support duckdb store type after v0.4.6, just support db store type now, "
|
||||
f"you can migrate your message according to {link1} or {link2}"
|
||||
"Not support duckdb store type after v0.4.6, just support db store "
|
||||
f"type now, you can migrate your message according to {link1} or {link2}"
|
||||
)
|
||||
raise ValueError(f"Invalid store type: {store_type}")
|
@@ -4,64 +4,76 @@ from typing import Dict, List, Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict
|
||||
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
|
||||
from dbgpt.storage.chat_history.chat_history_db import ChatHistoryDao, ChatHistoryEntity
|
||||
|
||||
from .base import BaseChatHistoryMemory, MemoryStoreType
|
||||
|
||||
CFG = Config()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
store_type: str = MemoryStoreType.DB.value
|
||||
"""Db history memory storage.
|
||||
|
||||
It is deprecated.
|
||||
"""
|
||||
|
||||
store_type: str = MemoryStoreType.DB.value # type: ignore
|
||||
|
||||
def __init__(self, chat_session_id: str):
|
||||
self.chat_seesion_id = chat_session_id
|
||||
self.chat_history_dao = ChatHistoryDao()
|
||||
|
||||
def messages(self) -> List[OnceConversation]:
|
||||
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
|
||||
chat_history: Optional[ChatHistoryEntity] = self.chat_history_dao.get_by_uid(
|
||||
self.chat_seesion_id
|
||||
)
|
||||
if chat_history:
|
||||
context = chat_history.messages
|
||||
if context:
|
||||
conversations: List[OnceConversation] = json.loads(context)
|
||||
conversations: List[OnceConversation] = json.loads(
|
||||
context # type: ignore
|
||||
)
|
||||
return conversations
|
||||
return []
|
||||
|
||||
def create(self, chat_mode, summary: str, user_name: str) -> None:
|
||||
try:
|
||||
chat_history: ChatHistoryEntity = ChatHistoryEntity()
|
||||
chat_history.chat_mode = chat_mode
|
||||
chat_history.summary = summary
|
||||
chat_history.user_name = user_name
|
||||
|
||||
self.chat_history_dao.raw_update(chat_history)
|
||||
except Exception as e:
|
||||
logger.error("init create conversation log error!" + str(e))
|
||||
|
||||
# def create(self, chat_mode, summary: str, user_name: str) -> None:
|
||||
# try:
|
||||
# chat_history: ChatHistoryEntity = ChatHistoryEntity()
|
||||
# chat_history.chat_mode = chat_mode
|
||||
# chat_history.summary = summary
|
||||
# chat_history.user_name = user_name
|
||||
#
|
||||
# self.chat_history_dao.raw_update(chat_history)
|
||||
# except Exception as e:
|
||||
# logger.error("init create conversation log error!" + str(e))
|
||||
#
|
||||
def append(self, once_message: OnceConversation) -> None:
|
||||
logger.debug(f"db history append: {once_message}")
|
||||
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
|
||||
chat_history: Optional[ChatHistoryEntity] = self.chat_history_dao.get_by_uid(
|
||||
self.chat_seesion_id
|
||||
)
|
||||
conversations: List[OnceConversation] = []
|
||||
conversations: List[Dict] = []
|
||||
latest_user_message = once_message.get_latest_user_message()
|
||||
summary = latest_user_message.content if latest_user_message else ""
|
||||
if chat_history:
|
||||
context = chat_history.messages
|
||||
if context:
|
||||
conversations = json.loads(context)
|
||||
conversations = json.loads(context) # type: ignore
|
||||
else:
|
||||
chat_history.summary = once_message.get_latest_user_message().content
|
||||
chat_history.summary = summary # type: ignore
|
||||
else:
|
||||
chat_history: ChatHistoryEntity = ChatHistoryEntity()
|
||||
chat_history.conv_uid = self.chat_seesion_id
|
||||
chat_history.chat_mode = once_message.chat_mode
|
||||
chat_history.user_name = once_message.user_name
|
||||
chat_history.sys_code = once_message.sys_code
|
||||
chat_history.summary = once_message.get_latest_user_message().content
|
||||
chat_history = ChatHistoryEntity()
|
||||
chat_history.conv_uid = self.chat_seesion_id # type: ignore
|
||||
chat_history.chat_mode = once_message.chat_mode # type: ignore
|
||||
chat_history.user_name = once_message.user_name # type: ignore
|
||||
chat_history.sys_code = once_message.sys_code # type: ignore
|
||||
chat_history.summary = summary # type: ignore
|
||||
|
||||
conversations.append(_conversation_to_dict(once_message))
|
||||
chat_history.messages = json.dumps(conversations, ensure_ascii=False)
|
||||
chat_history.messages = json.dumps( # type: ignore
|
||||
conversations, ensure_ascii=False
|
||||
)
|
||||
|
||||
self.chat_history_dao.raw_update(chat_history)
|
||||
|
||||
@@ -72,18 +84,13 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
|
||||
def delete(self) -> bool:
|
||||
self.chat_history_dao.raw_delete(self.chat_seesion_id)
|
||||
return True
|
||||
|
||||
def conv_info(self, conv_uid: str = None) -> None:
|
||||
logger.info("conv_info:{}", conv_uid)
|
||||
chat_history = self.chat_history_dao.get_by_uid(conv_uid)
|
||||
return chat_history.__dict__
|
||||
|
||||
def get_messages(self) -> List[OnceConversation]:
|
||||
# logger.info("get_messages:{}", self.chat_seesion_id)
|
||||
def get_messages(self) -> List[Dict]:
|
||||
chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
|
||||
if chat_history:
|
||||
context = chat_history.messages
|
||||
return json.loads(context)
|
||||
return json.loads(context) # type: ignore
|
||||
return []
|
||||
|
||||
@staticmethod
|
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
|
||||
@@ -23,9 +23,9 @@ from dbgpt.app.openapi.editor_view_model import (
|
||||
)
|
||||
from dbgpt.app.scene import ChatFactory
|
||||
from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||
from dbgpt.core.interface.message import OnceConversation
|
||||
from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
||||
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
|
||||
|
||||
from ._chat_history.chat_hisotry_factory import ChatHistory
|
||||
|
||||
router = APIRouter()
|
||||
CFG = Config()
|
||||
@@ -201,7 +201,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
|
||||
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
history_messages: List[Dict] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
|
||||
|
@@ -21,10 +21,9 @@ from dbgpt.core.awel import (
|
||||
from dbgpt.core.operators import (
|
||||
BufferedConversationMapperOperator,
|
||||
HistoryPromptBuilderOperator,
|
||||
LLMBranchOperator,
|
||||
)
|
||||
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
|
||||
from dbgpt.storage.cache.operator import (
|
||||
from dbgpt.storage.cache.operators import (
|
||||
CachedModelOperator,
|
||||
CachedModelStreamOperator,
|
||||
CacheManager,
|
||||
|
@@ -33,7 +33,7 @@ async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIter
|
||||
yield iter_data
|
||||
|
||||
|
||||
class IteratorTrigger(Trigger):
|
||||
class IteratorTrigger(Trigger[List[Tuple[Any, Any]]]):
|
||||
"""Trigger for iterator data.
|
||||
|
||||
Trigger the dag with iterator data.
|
||||
@@ -46,6 +46,7 @@ class IteratorTrigger(Trigger):
|
||||
data: IterDataType,
|
||||
parallel_num: int = 1,
|
||||
streaming_call: bool = False,
|
||||
show_progress: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a IteratorTrigger.
|
||||
@@ -60,6 +61,7 @@ class IteratorTrigger(Trigger):
|
||||
self._iter_data = data
|
||||
self._parallel_num = parallel_num
|
||||
self._streaming_call = streaming_call
|
||||
self._show_progress = show_progress
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def trigger(
|
||||
@@ -132,17 +134,27 @@ class IteratorTrigger(Trigger):
|
||||
async def call_stream(call_data: Any):
|
||||
async for out in await end_node.call_stream(call_data):
|
||||
yield out
|
||||
await dag._after_dag_end()
|
||||
|
||||
async def run_node(call_data: Any):
|
||||
async def run_node(call_data: Any) -> Tuple[Any, Any]:
|
||||
async with semaphore:
|
||||
if streaming_call:
|
||||
task_output = call_stream(call_data)
|
||||
else:
|
||||
task_output = await end_node.call(call_data)
|
||||
await dag._after_dag_end()
|
||||
return call_data, task_output
|
||||
|
||||
tasks = []
|
||||
|
||||
if self._show_progress:
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
async_module = tqdm_asyncio
|
||||
else:
|
||||
async_module = asyncio # type: ignore
|
||||
|
||||
async for data in _to_async_iterator(self._iter_data, task_id):
|
||||
tasks.append(run_node(data))
|
||||
results = await asyncio.gather(*tasks)
|
||||
results: List[Tuple[Any, Any]] = await async_module.gather(*tasks)
|
||||
return results
|
||||
|
@@ -10,7 +10,7 @@ from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
|
||||
|
||||
from dbgpt.rag.embedding import Embeddings
|
||||
from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings
|
||||
|
||||
|
||||
class EmbeddingLoader:
|
||||
@@ -47,7 +47,7 @@ class EmbeddingLoader:
|
||||
openapi_param["model_name"] = proxy_param.proxy_backend
|
||||
return OpenAPIEmbeddings(**openapi_param)
|
||||
else:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from dbgpt.rag.embedding import HuggingFaceEmbeddings
|
||||
|
||||
kwargs = param.build_kwargs(model_name=param.model_path)
|
||||
return HuggingFaceEmbeddings(**kwargs)
|
||||
|
@@ -1,7 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.model.cluster.manager_base import WorkerManager
|
||||
|
||||
|
||||
|
@@ -20,14 +20,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingsModelWorker(ModelWorker):
|
||||
def __init__(self) -> None:
|
||||
try:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import langchain.embeddings.HuggingFaceEmbeddings python package. "
|
||||
"Please install it with `pip install langchain`."
|
||||
) from exc
|
||||
from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings
|
||||
|
||||
self._embeddings_impl: Embeddings = None
|
||||
self._model_params = None
|
||||
self.model_name = None
|
||||
|
@@ -3,8 +3,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.rag.chunk import Chunk, Document
|
||||
from dbgpt.rag.extractor.base import Extractor
|
||||
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
|
||||
|
@@ -411,6 +411,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
end_node = cast(BaseOperator, leaf_nodes[0])
|
||||
async for output in _chat_with_dag_task(end_node, request, incremental):
|
||||
yield output
|
||||
await dag._after_dag_end()
|
||||
|
||||
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
|
||||
"""Parse the flow category
|
||||
|
@@ -0,0 +1,3 @@
|
||||
"""Module of storage."""
|
||||
|
||||
from .schema import DBType # noqa: F401
|
||||
|
7
dbgpt/storage/cache/__init__.py
vendored
7
dbgpt/storage/cache/__init__.py
vendored
@@ -1,6 +1,7 @@
|
||||
from dbgpt.storage.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
|
||||
from dbgpt.storage.cache.manager import CacheManager, initialize_cache
|
||||
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
|
||||
"""Module for cache storage."""
|
||||
from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue # noqa: F401
|
||||
from .manager import CacheManager, initialize_cache # noqa: F401
|
||||
from .storage.base import MemoryCacheStorage # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"LLMCacheKey",
|
||||
|
1
dbgpt/storage/cache/embedding_cache.py
vendored
1
dbgpt/storage/cache/embedding_cache.py
vendored
@@ -0,0 +1 @@
|
||||
"""Embeddings cache."""
|
||||
|
83
dbgpt/storage/cache/llm_cache.py
vendored
83
dbgpt/storage/cache/llm_cache.py
vendored
@@ -1,21 +1,26 @@
|
||||
"""Cache client for LLM."""
|
||||
|
||||
import hashlib
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from dbgpt.core import ModelOutput, Serializer
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.storage.cache.manager import CacheManager
|
||||
|
||||
from .manager import CacheManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMCacheKeyData:
|
||||
"""Cache key data for LLM."""
|
||||
|
||||
prompt: str
|
||||
model_name: str
|
||||
temperature: Optional[float] = 0.7
|
||||
max_new_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = 1.0
|
||||
model_type: Optional[str] = ModelType.HF
|
||||
# See dbgpt.model.base.ModelType
|
||||
model_type: Optional[str] = "huggingface"
|
||||
|
||||
|
||||
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
|
||||
@@ -23,12 +28,15 @@ CacheOutputType = Union[ModelOutput, List[ModelOutput]]
|
||||
|
||||
@dataclass
|
||||
class LLMCacheValueData:
|
||||
"""Cache value data for LLM."""
|
||||
|
||||
output: CacheOutputType
|
||||
user: Optional[str] = None
|
||||
_is_list: Optional[bool] = False
|
||||
_is_list: bool = False
|
||||
|
||||
@staticmethod
|
||||
def from_dict(**kwargs) -> "LLMCacheValueData":
|
||||
"""Create LLMCacheValueData object from dict."""
|
||||
output = kwargs.get("output")
|
||||
if not output:
|
||||
raise ValueError("Can't new LLMCacheValueData object, output is None")
|
||||
@@ -46,6 +54,7 @@ class LLMCacheValueData:
|
||||
return LLMCacheValueData(**kwargs)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict."""
|
||||
output = self.output
|
||||
is_list = False
|
||||
if isinstance(output, list):
|
||||
@@ -53,16 +62,18 @@ class LLMCacheValueData:
|
||||
is_list = True
|
||||
for out in output:
|
||||
output_list.append(out.to_dict())
|
||||
output = output_list
|
||||
output = output_list # type: ignore
|
||||
else:
|
||||
output = output.to_dict()
|
||||
output = output.to_dict() # type: ignore
|
||||
return {"output": output, "_is_list": is_list, "user": self.user}
|
||||
|
||||
@property
|
||||
def is_list(self) -> bool:
|
||||
"""Return whether the output is a list."""
|
||||
return self._is_list
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string representation."""
|
||||
if not isinstance(self.output, list):
|
||||
return f"user: {self.user}, output: {self.output}"
|
||||
else:
|
||||
@@ -70,74 +81,116 @@ class LLMCacheValueData:
|
||||
|
||||
|
||||
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
|
||||
"""Cache key for LLM."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""Create a new instance of LLMCacheKey."""
|
||||
super().__init__()
|
||||
self.config = LLMCacheKeyData(**kwargs)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return the hash value of the object."""
|
||||
serialize_bytes = self.serialize()
|
||||
return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check equality with another key."""
|
||||
if not isinstance(other, LLMCacheKey):
|
||||
return False
|
||||
return self.config == other.config
|
||||
|
||||
def get_hash_bytes(self) -> bytes:
|
||||
"""Return the byte array of hash value.
|
||||
|
||||
Returns:
|
||||
bytes: The byte array of hash value.
|
||||
"""
|
||||
serialize_bytes = self.serialize()
|
||||
return hashlib.sha256(serialize_bytes).digest()
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict."""
|
||||
return asdict(self.config)
|
||||
|
||||
def get_value(self) -> LLMCacheKeyData:
|
||||
"""Return the real object of current cache key."""
|
||||
return self.config
|
||||
|
||||
|
||||
class LLMCacheValue(CacheValue[LLMCacheValueData]):
|
||||
"""Cache value for LLM."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""Create a new instance of LLMCacheValue."""
|
||||
super().__init__()
|
||||
self.value = LLMCacheValueData.from_dict(**kwargs)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict."""
|
||||
return self.value.to_dict()
|
||||
|
||||
def get_value(self) -> LLMCacheValueData:
|
||||
"""Return the underlying real value."""
|
||||
return self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string representation."""
|
||||
return f"value: {str(self.value)}"
|
||||
|
||||
|
||||
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
|
||||
"""Cache client for LLM."""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager) -> None:
|
||||
"""Create a new instance of LLMCacheClient."""
|
||||
super().__init__()
|
||||
self._cache_manager: CacheManager = cache_manager
|
||||
|
||||
async def get(
|
||||
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
|
||||
self,
|
||||
key: LLMCacheKey, # type: ignore
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> Optional[LLMCacheValue]:
|
||||
return await self._cache_manager.get(key, LLMCacheValue, cache_config)
|
||||
"""Retrieve a value from the cache using the provided key.
|
||||
|
||||
Args:
|
||||
key (LLMCacheKey): The key to get cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
|
||||
Returns:
|
||||
Optional[LLMCacheValue]: The value retrieved according to key. If cache key
|
||||
not exist, return None.
|
||||
"""
|
||||
return cast(
|
||||
LLMCacheValue,
|
||||
await self._cache_manager.get(key, LLMCacheValue, cache_config),
|
||||
)
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: LLMCacheKey,
|
||||
value: LLMCacheValue,
|
||||
key: LLMCacheKey, # type: ignore
|
||||
value: LLMCacheValue, # type: ignore
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
"""Set a value in the cache for the provided key."""
|
||||
return await self._cache_manager.set(key, value, cache_config)
|
||||
|
||||
async def exists(
|
||||
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
|
||||
self,
|
||||
key: LLMCacheKey, # type: ignore
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> bool:
|
||||
"""Check if a key exists in the cache."""
|
||||
return await self.get(key, cache_config) is not None
|
||||
|
||||
def new_key(self, **kwargs) -> LLMCacheKey:
|
||||
def new_key(self, **kwargs) -> LLMCacheKey: # type: ignore
|
||||
"""Create a cache key with params."""
|
||||
key = LLMCacheKey(**kwargs)
|
||||
key.set_serializer(self._cache_manager.serializer)
|
||||
return key
|
||||
|
||||
def new_value(self, **kwargs) -> LLMCacheValue:
|
||||
def new_value(self, **kwargs) -> LLMCacheValue: # type: ignore
|
||||
"""Create a cache value with params."""
|
||||
value = LLMCacheValue(**kwargs)
|
||||
value.set_serializer(self._cache_manager.serializer)
|
||||
return value
|
||||
|
54
dbgpt/storage/cache/manager.py
vendored
54
dbgpt/storage/cache/manager.py
vendored
@@ -1,24 +1,31 @@
|
||||
"""Cache manager."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Executor
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, cast
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.core import CacheConfig, CacheKey, CacheValue, Serializable, Serializer
|
||||
from dbgpt.core.interface.cache import K, V
|
||||
from dbgpt.storage.cache.storage.base import CacheStorage
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
|
||||
from .storage.base import CacheStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheManager(BaseComponent, ABC):
|
||||
"""The cache manager interface."""
|
||||
|
||||
name = ComponentType.MODEL_CACHE_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp | None = None):
|
||||
"""Create cache manager."""
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize cache manager."""
|
||||
self.system_app = system_app
|
||||
|
||||
@abstractmethod
|
||||
@@ -28,7 +35,7 @@ class CacheManager(BaseComponent, ABC):
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
):
|
||||
"""Set cache"""
|
||||
"""Set cache with key."""
|
||||
|
||||
@abstractmethod
|
||||
async def get(
|
||||
@@ -36,27 +43,30 @@ class CacheManager(BaseComponent, ABC):
|
||||
key: CacheKey[K],
|
||||
cls: Type[Serializable],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> CacheValue[V]:
|
||||
"""Get cache with key"""
|
||||
) -> Optional[CacheValue[V]]:
|
||||
"""Retrieve cache with key."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def serializer(self) -> Serializer:
|
||||
"""Get cache serializer"""
|
||||
"""Return serializer to serialize/deserialize cache value."""
|
||||
|
||||
|
||||
class LocalCacheManager(CacheManager):
|
||||
"""Local cache manager."""
|
||||
|
||||
def __init__(
|
||||
self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
|
||||
) -> None:
|
||||
"""Create local cache manager."""
|
||||
super().__init__(system_app)
|
||||
self._serializer = serializer
|
||||
self._storage = storage
|
||||
|
||||
@property
|
||||
def executor(self) -> Executor:
|
||||
"""Return executor to submit task"""
|
||||
self._executor = self.system_app.get_component(
|
||||
"""Return executor."""
|
||||
return self.system_app.get_component( # type: ignore
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
|
||||
@@ -66,6 +76,7 @@ class LocalCacheManager(CacheManager):
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
):
|
||||
"""Set cache with key."""
|
||||
if self._storage.support_async():
|
||||
await self._storage.aset(key, value, cache_config)
|
||||
else:
|
||||
@@ -78,7 +89,8 @@ class LocalCacheManager(CacheManager):
|
||||
key: CacheKey[K],
|
||||
cls: Type[Serializable],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> CacheValue[V]:
|
||||
) -> Optional[CacheValue[V]]:
|
||||
"""Retrieve cache with key."""
|
||||
if self._storage.support_async():
|
||||
item_bytes = await self._storage.aget(key, cache_config)
|
||||
else:
|
||||
@@ -87,30 +99,42 @@ class LocalCacheManager(CacheManager):
|
||||
)
|
||||
if not item_bytes:
|
||||
return None
|
||||
return self._serializer.deserialize(item_bytes.value_data, cls)
|
||||
return cast(
|
||||
CacheValue[V], self._serializer.deserialize(item_bytes.value_data, cls)
|
||||
)
|
||||
|
||||
@property
|
||||
def serializer(self) -> Serializer:
|
||||
"""Return serializer to serialize/deserialize cache value."""
|
||||
return self._serializer
|
||||
|
||||
|
||||
def initialize_cache(
|
||||
system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
|
||||
):
|
||||
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
|
||||
"""Initialize cache manager.
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app.
|
||||
storage_type (str): The storage type.
|
||||
max_memory_mb (int): The max memory in MB.
|
||||
persist_dir (str): The persist directory.
|
||||
"""
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
cache_storage = None
|
||||
from .storage.base import MemoryCacheStorage
|
||||
|
||||
if storage_type == "disk":
|
||||
try:
|
||||
from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage
|
||||
from .storage.disk.disk_storage import DiskCacheStorage
|
||||
|
||||
cache_storage = DiskCacheStorage(
|
||||
cache_storage: CacheStorage = DiskCacheStorage(
|
||||
persist_dir, mem_table_buffer_mb=max_memory_mb
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warn(
|
||||
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
|
||||
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error "
|
||||
f"message: {str(e)}"
|
||||
)
|
||||
cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
|
||||
else:
|
||||
|
@@ -1,5 +1,6 @@
|
||||
"""Operators for processing model outputs with caching support."""
|
||||
import logging
|
||||
from typing import AsyncIterator, Dict, List, Union
|
||||
from typing import AsyncIterator, Dict, List, Optional, Union, cast
|
||||
|
||||
from dbgpt.core import ModelOutput, ModelRequest
|
||||
from dbgpt.core.awel import (
|
||||
@@ -10,7 +11,9 @@ from dbgpt.core.awel import (
|
||||
StreamifyAbsOperator,
|
||||
TransformStreamAbsOperator,
|
||||
)
|
||||
from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue
|
||||
|
||||
from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
|
||||
from .manager import CacheManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,15 +29,17 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Methods:
|
||||
streamify: Processes a stream of inputs with cache support, yielding model outputs.
|
||||
streamify: Processes a stream of inputs with cache support, yielding model
|
||||
outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
|
||||
"""Create a new instance of CachedModelStreamOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
|
||||
async def streamify(self, input_value: ModelRequest) -> AsyncIterator[ModelOutput]:
|
||||
async def streamify(self, input_value: ModelRequest):
|
||||
"""Process inputs as a stream with cache support and yield model outputs.
|
||||
|
||||
Args:
|
||||
@@ -45,10 +50,13 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
|
||||
"""
|
||||
cache_dict = _parse_cache_key_dict(input_value)
|
||||
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
|
||||
llm_cache_value = await self._client.get(llm_cache_key)
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
for out in llm_cache_value.get_value().output:
|
||||
yield out
|
||||
if not llm_cache_value:
|
||||
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
|
||||
outputs = cast(List[ModelOutput], llm_cache_value.get_value().output)
|
||||
for out in outputs:
|
||||
yield cast(ModelOutput, out)
|
||||
|
||||
|
||||
class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
|
||||
@@ -63,6 +71,7 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
|
||||
"""Create a new instance of CachedModelOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
@@ -78,14 +87,18 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
|
||||
"""
|
||||
cache_dict = _parse_cache_key_dict(input_value)
|
||||
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
|
||||
llm_cache_value = await self._client.get(llm_cache_key)
|
||||
if not llm_cache_value:
|
||||
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
return llm_cache_value.get_value().output
|
||||
return cast(ModelOutput, llm_cache_value.get_value().output)
|
||||
|
||||
|
||||
class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
|
||||
"""
|
||||
A branch operator that decides whether to use cached data or to process data using the model.
|
||||
"""Branch operator for model processing with cache support.
|
||||
|
||||
A branch operator that decides whether to use cached data or to process data using
|
||||
the model.
|
||||
|
||||
Args:
|
||||
cache_manager (CacheManager): The cache manager for managing cache operations.
|
||||
@@ -101,6 +114,7 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
|
||||
cache_task_name: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new instance of ModelCacheBranchOperator."""
|
||||
super().__init__(branches=None, **kwargs)
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
@@ -110,10 +124,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
|
||||
async def branches(
|
||||
self,
|
||||
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
|
||||
"""Defines branch logic based on cache availability.
|
||||
"""Branch logic based on cache availability.
|
||||
|
||||
Defines branch logic based on cache availability.
|
||||
|
||||
Returns:
|
||||
Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
|
||||
Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping
|
||||
branch functions to task names.
|
||||
"""
|
||||
|
||||
async def check_cache_true(input_value: ModelRequest) -> bool:
|
||||
@@ -124,12 +141,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
|
||||
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
cache_value = await self._client.get(cache_key)
|
||||
logger.debug(
|
||||
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
|
||||
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: "
|
||||
f"{cache_value}"
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True
|
||||
)
|
||||
return True if cache_value else False
|
||||
return bool(cache_value)
|
||||
|
||||
async def check_cache_false(input_value: ModelRequest):
|
||||
# Inverse of check_cache_true
|
||||
@@ -152,22 +170,25 @@ class ModelStreamSaveCacheOperator(
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
"""Create a new instance of ModelStreamSaveCacheOperator."""
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[ModelOutput]
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
"""Transforms the input stream by saving the outputs to cache.
|
||||
async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
|
||||
"""Save the stream of model outputs to cache.
|
||||
|
||||
Transforms the input stream by saving the outputs to cache.
|
||||
|
||||
Args:
|
||||
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
|
||||
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model
|
||||
outputs.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
|
||||
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are
|
||||
saved to cache.
|
||||
"""
|
||||
llm_cache_key: LLMCacheKey = None
|
||||
llm_cache_key: Optional[LLMCacheKey] = None
|
||||
outputs = []
|
||||
async for out in input_value:
|
||||
if not llm_cache_key:
|
||||
@@ -190,12 +211,13 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
"""Create a new instance of ModelSaveCacheOperator."""
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> ModelOutput:
|
||||
"""Saves a single model output to cache and returns it.
|
||||
"""Save model output to cache.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The output from the model to be cached.
|
||||
@@ -213,7 +235,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
|
||||
|
||||
|
||||
def _parse_cache_key_dict(input_value: ModelRequest) -> Dict:
|
||||
"""Parses and extracts relevant fields from input to form a cache key dictionary.
|
||||
"""Parse and extract relevant fields from input to form a cache key dictionary.
|
||||
|
||||
Args:
|
||||
input_value (Dict): The input dictionary containing model and prompt parameters.
|
1
dbgpt/storage/cache/protocol/__init__.py
vendored
Normal file
1
dbgpt/storage/cache/protocol/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
"""Module for protocol."""
|
1
dbgpt/storage/cache/storage/__init__.py
vendored
1
dbgpt/storage/cache/storage/__init__.py
vendored
@@ -0,0 +1 @@
|
||||
"""Module for cache storage implementation."""
|
||||
|
26
dbgpt/storage/cache/storage/base.py
vendored
26
dbgpt/storage/cache/storage/base.py
vendored
@@ -1,3 +1,4 @@
|
||||
"""Base cache storage class."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
@@ -22,8 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class StorageItem:
|
||||
"""
|
||||
A class representing a storage item.
|
||||
"""A class representing a storage item.
|
||||
|
||||
This class encapsulates data related to a storage item, such as its length,
|
||||
the hash of the key, and the data for both the key and value.
|
||||
@@ -44,6 +44,7 @@ class StorageItem:
|
||||
def build_from(
|
||||
key_hash: bytes, key_data: bytes, value_data: bytes
|
||||
) -> "StorageItem":
|
||||
"""Build a StorageItem from the provided key and value data."""
|
||||
length = (
|
||||
32
|
||||
+ _get_object_bytes(key_hash)
|
||||
@@ -56,6 +57,7 @@ class StorageItem:
|
||||
|
||||
@staticmethod
|
||||
def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
|
||||
"""Build a StorageItem from the provided key and value."""
|
||||
key_hash = key.get_hash_bytes()
|
||||
key_data = key.serialize()
|
||||
value_data = value.serialize()
|
||||
@@ -105,6 +107,8 @@ class StorageItem:
|
||||
|
||||
|
||||
class CacheStorage(ABC):
|
||||
"""Base class for cache storage."""
|
||||
|
||||
@abstractmethod
|
||||
def check_config(
|
||||
self,
|
||||
@@ -122,6 +126,7 @@ class CacheStorage(ABC):
|
||||
"""
|
||||
|
||||
def support_async(self) -> bool:
|
||||
"""Check whether the storage support async operation."""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
@@ -135,7 +140,8 @@ class CacheStorage(ABC):
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
|
||||
Returns:
|
||||
Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None.
|
||||
Optional[StorageItem]: The storage item retrieved according to key. If
|
||||
cache key not exist, return None.
|
||||
"""
|
||||
|
||||
async def aget(
|
||||
@@ -148,7 +154,8 @@ class CacheStorage(ABC):
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
|
||||
Returns:
|
||||
Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None.
|
||||
Optional[StorageItem]: The storage item of bytes retrieved according to
|
||||
key. If cache key not exist, return None.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -184,8 +191,11 @@ class CacheStorage(ABC):
|
||||
|
||||
|
||||
class MemoryCacheStorage(CacheStorage):
|
||||
"""A simple in-memory cache storage implementation."""
|
||||
|
||||
def __init__(self, max_memory_mb: int = 256):
|
||||
self.cache = OrderedDict()
|
||||
"""Create a new instance of MemoryCacheStorage."""
|
||||
self.cache: OrderedDict = OrderedDict()
|
||||
self.max_memory = max_memory_mb * 1024 * 1024
|
||||
self.current_memory_usage = 0
|
||||
|
||||
@@ -194,6 +204,7 @@ class MemoryCacheStorage(CacheStorage):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
raise_error: Optional[bool] = True,
|
||||
) -> bool:
|
||||
"""Check whether the CacheConfig is legal."""
|
||||
if (
|
||||
cache_config
|
||||
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
|
||||
@@ -208,10 +219,11 @@ class MemoryCacheStorage(CacheStorage):
|
||||
def get(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[StorageItem]:
|
||||
"""Retrieve a storage item from the cache using the provided key."""
|
||||
self.check_config(cache_config, raise_error=True)
|
||||
# Exact match retrieval
|
||||
key_hash = hash(key)
|
||||
item: StorageItem = self.cache.get(key_hash)
|
||||
item: Optional[StorageItem] = self.cache.get(key_hash)
|
||||
logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
|
||||
|
||||
if not item:
|
||||
@@ -226,6 +238,7 @@ class MemoryCacheStorage(CacheStorage):
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
"""Set a value in the cache for the provided key."""
|
||||
key_hash = hash(key)
|
||||
item = StorageItem.build_from_kv(key, value)
|
||||
# Calculate memory size of the new entry
|
||||
@@ -242,6 +255,7 @@ class MemoryCacheStorage(CacheStorage):
|
||||
def exists(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> bool:
|
||||
"""Check if the key exists in the cache."""
|
||||
return self.get(key, cache_config) is not None
|
||||
|
||||
def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
|
||||
|
1
dbgpt/storage/cache/storage/disk/__init__.py
vendored
1
dbgpt/storage/cache/storage/disk/__init__.py
vendored
@@ -0,0 +1 @@
|
||||
"""Disk cache storage implementation."""
|
||||
|
22
dbgpt/storage/cache/storage/disk/disk_storage.py
vendored
22
dbgpt/storage/cache/storage/disk/disk_storage.py
vendored
@@ -1,3 +1,7 @@
|
||||
"""Disk storage for cache.
|
||||
|
||||
Implement the cache storage using rocksdb.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -11,14 +15,14 @@ from dbgpt.core.interface.cache import (
|
||||
RetrievalPolicy,
|
||||
V,
|
||||
)
|
||||
from dbgpt.storage.cache.storage.base import CacheStorage, StorageItem
|
||||
|
||||
from ..base import CacheStorage, StorageItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def db_options(
|
||||
mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
|
||||
):
|
||||
def db_options(mem_table_buffer_mb: int = 256, background_threads: int = 2):
|
||||
"""Create rocksdb options."""
|
||||
opt = Options()
|
||||
# create table
|
||||
opt.create_if_missing(True)
|
||||
@@ -42,9 +46,10 @@ def db_options(
|
||||
|
||||
|
||||
class DiskCacheStorage(CacheStorage):
|
||||
def __init__(
|
||||
self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
|
||||
) -> None:
|
||||
"""Disk cache storage using rocksdb."""
|
||||
|
||||
def __init__(self, persist_dir: str, mem_table_buffer_mb: int = 256) -> None:
|
||||
"""Create a new instance of DiskCacheStorage."""
|
||||
super().__init__()
|
||||
self.db: Rdict = Rdict(
|
||||
persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
|
||||
@@ -55,6 +60,7 @@ class DiskCacheStorage(CacheStorage):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
raise_error: Optional[bool] = True,
|
||||
) -> bool:
|
||||
"""Check whether the CacheConfig is legal."""
|
||||
if (
|
||||
cache_config
|
||||
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
|
||||
@@ -69,6 +75,7 @@ class DiskCacheStorage(CacheStorage):
|
||||
def get(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[StorageItem]:
|
||||
"""Retrieve a storage item from the cache using the provided key."""
|
||||
self.check_config(cache_config, raise_error=True)
|
||||
|
||||
# Exact match retrieval
|
||||
@@ -86,6 +93,7 @@ class DiskCacheStorage(CacheStorage):
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
"""Set a value in the cache for the provided key."""
|
||||
item = StorageItem.build_from_kv(key, value)
|
||||
key_hash = item.key_hash
|
||||
self.db[key_hash] = item.serialize()
|
||||
|
@@ -1,5 +1,3 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.util.memory_utils import _get_object_bytes
|
||||
|
||||
from ..base import StorageItem
|
||||
|
@@ -1 +1,19 @@
|
||||
"""Module of chat history."""
|
||||
|
||||
from .chat_history_db import ( # noqa: F401
|
||||
ChatHistoryDao,
|
||||
ChatHistoryEntity,
|
||||
ChatHistoryMessageEntity,
|
||||
)
|
||||
from .storage_adapter import ( # noqa: F401
|
||||
DBMessageStorageItemAdapter,
|
||||
DBStorageConversationItemAdapter,
|
||||
)
|
||||
|
||||
# Public API of the chat history package. The name must be the lowercase
# ``__all__``: that is the only spelling honoured by ``from ... import *`` and by
# static analysis tools, any other casing is just an unused module attribute.
__all__ = [
    "ChatHistoryEntity",
    "ChatHistoryMessageEntity",
    "ChatHistoryDao",
    "DBStorageConversationItemAdapter",
    "DBMessageStorageItemAdapter",
]
|
||||
|
@@ -1,12 +1,15 @@
|
||||
"""Chat history database model."""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from ..metadata import BaseDao, Model
|
||||
|
||||
|
||||
class ChatHistoryEntity(Model):
|
||||
"""Chat history entity."""
|
||||
|
||||
__tablename__ = "chat_history"
|
||||
__table_args__ = (UniqueConstraint("conv_uid", name="uk_conv_uid"),)
|
||||
id = Column(
|
||||
@@ -14,13 +17,16 @@ class ChatHistoryEntity(Model):
|
||||
)
|
||||
conv_uid = Column(
|
||||
String(255),
|
||||
# Change from False to True, the alembic migration will fail, so we use UniqueConstraint to replace it
|
||||
# Change from False to True, the alembic migration will fail, so we use
|
||||
# UniqueConstraint to replace it
|
||||
unique=False,
|
||||
nullable=False,
|
||||
comment="Conversation record unique id",
|
||||
)
|
||||
chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
|
||||
summary = Column(String(255), nullable=False, comment="Conversation record summary")
|
||||
summary = Column(
|
||||
Text(length=2**31 - 1), nullable=False, comment="Conversation record summary"
|
||||
)
|
||||
user_name = Column(String(255), nullable=True, comment="interlocutor")
|
||||
messages = Column(
|
||||
Text(length=2**31 - 1), nullable=True, comment="Conversation details"
|
||||
@@ -38,6 +44,8 @@ class ChatHistoryEntity(Model):
|
||||
|
||||
|
||||
class ChatHistoryMessageEntity(Model):
|
||||
"""Chat history message entity."""
|
||||
|
||||
__tablename__ = "chat_history_message"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("conv_uid", "index", name="uk_conversation_message"),
|
||||
@@ -61,9 +69,12 @@ class ChatHistoryMessageEntity(Model):
|
||||
|
||||
|
||||
class ChatHistoryDao(BaseDao):
|
||||
"""Chat history dao."""
|
||||
|
||||
def list_last_20(
|
||||
self, user_name: Optional[str] = None, sys_code: Optional[str] = None
|
||||
):
|
||||
"""Retrieve the last 20 chat history records."""
|
||||
session = self.get_raw_session()
|
||||
chat_history = session.query(ChatHistoryEntity)
|
||||
if user_name:
|
||||
@@ -78,6 +89,7 @@ class ChatHistoryDao(BaseDao):
|
||||
return result
|
||||
|
||||
def raw_update(self, entity: ChatHistoryEntity):
|
||||
"""Update the chat history record."""
|
||||
session = self.get_raw_session()
|
||||
try:
|
||||
updated = session.merge(entity)
|
||||
@@ -87,6 +99,7 @@ class ChatHistoryDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def update_message_by_uid(self, message: str, conv_uid: str):
|
||||
"""Update the chat history record."""
|
||||
session = self.get_raw_session()
|
||||
try:
|
||||
chat_history = session.query(ChatHistoryEntity)
|
||||
@@ -97,7 +110,8 @@ class ChatHistoryDao(BaseDao):
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def raw_delete(self, conv_uid: int):
|
||||
def raw_delete(self, conv_uid: str):
|
||||
"""Delete the chat history record."""
|
||||
if conv_uid is None:
|
||||
raise Exception("conv_uid is None")
|
||||
with self.session() as session:
|
||||
@@ -106,5 +120,6 @@ class ChatHistoryDao(BaseDao):
|
||||
chat_history.delete()
|
||||
|
||||
def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]:
    """Look up a single chat history record by its conversation uid.

    Args:
        conv_uid (str): The conversation uid to match.

    Returns:
        Optional[ChatHistoryEntity]: The first matching record, or None when no
            record has this conv_uid.
    """
    # The session is opened with commit=False because this is a pure lookup.
    with self.session(commit=False) as session:
        matching = session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid)
        return matching.first()
|
||||
|
@@ -1,5 +1,7 @@
|
||||
"""Adapter for chat history storage."""
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Type
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -20,16 +22,23 @@ from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity
|
||||
class DBStorageConversationItemAdapter(
|
||||
StorageItemAdapter[StorageConversation, ChatHistoryEntity]
|
||||
):
|
||||
"""Adapter for chat history storage."""
|
||||
|
||||
def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity:
|
||||
"""Convert to storage format."""
|
||||
message_ids = ",".join(item.message_ids)
|
||||
messages = None
|
||||
if not item.save_message_independent and item.messages:
|
||||
message_dict_list = [_conversation_to_dict(item)]
|
||||
messages = json.dumps(message_dict_list, ensure_ascii=False)
|
||||
summary = item.summary
|
||||
latest_user_message = item.get_latest_user_message()
|
||||
if not summary and latest_user_message is not None:
|
||||
summary = latest_user_message.content
|
||||
return ChatHistoryEntity(
|
||||
conv_uid=item.conv_uid,
|
||||
chat_mode=item.chat_mode,
|
||||
summary=item.summary or item.get_latest_user_message().content,
|
||||
summary=summary,
|
||||
user_name=item.user_name,
|
||||
# Messages are no longer saved to the chat_history table in the new design
|
||||
messages=messages,
|
||||
@@ -38,25 +47,25 @@ class DBStorageConversationItemAdapter(
|
||||
)
|
||||
|
||||
def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation:
|
||||
"""Convert from storage format."""
|
||||
message_ids = model.message_ids.split(",") if model.message_ids else []
|
||||
old_conversations: List[Dict] = (
|
||||
json.loads(model.messages) if model.messages else []
|
||||
json.loads(model.messages) if model.messages else [] # type: ignore
|
||||
)
|
||||
old_messages = []
|
||||
save_message_independent = True
|
||||
if old_conversations:
|
||||
# Load old messages from old conversations, in old design, we save messages to chat_history table
|
||||
# Load old messages from old conversations, in old design, we save messages
|
||||
# to chat_history table
|
||||
save_message_independent = False
|
||||
old_messages: List[BaseMessage] = _parse_old_conversations(
|
||||
old_conversations
|
||||
)
|
||||
old_messages = _parse_old_conversations(old_conversations)
|
||||
return StorageConversation(
|
||||
conv_uid=model.conv_uid,
|
||||
chat_mode=model.chat_mode,
|
||||
summary=model.summary,
|
||||
user_name=model.user_name,
|
||||
conv_uid=model.conv_uid, # type: ignore
|
||||
chat_mode=model.chat_mode, # type: ignore
|
||||
summary=model.summary, # type: ignore
|
||||
user_name=model.user_name, # type: ignore
|
||||
message_ids=message_ids,
|
||||
sys_code=model.sys_code,
|
||||
sys_code=model.sys_code, # type: ignore
|
||||
save_message_independent=save_message_independent,
|
||||
messages=old_messages,
|
||||
)
|
||||
@@ -64,10 +73,11 @@ class DBStorageConversationItemAdapter(
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[ChatHistoryEntity],
|
||||
resource_id: ConversationIdentifier,
|
||||
resource_id: ConversationIdentifier, # type: ignore
|
||||
**kwargs,
|
||||
):
|
||||
session: Session = kwargs.get("session")
|
||||
"""Get query for identifier."""
|
||||
session: Optional[Session] = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
return session.query(ChatHistoryEntity).filter(
|
||||
@@ -78,7 +88,10 @@ class DBStorageConversationItemAdapter(
|
||||
class DBMessageStorageItemAdapter(
|
||||
StorageItemAdapter[MessageStorageItem, ChatHistoryMessageEntity]
|
||||
):
|
||||
"""Adapter for chat history message storage."""
|
||||
|
||||
def to_storage_format(self, item: MessageStorageItem) -> ChatHistoryMessageEntity:
|
||||
"""Convert to storage format."""
|
||||
round_index = item.message_detail.get("round_index", 0)
|
||||
message_detail = json.dumps(item.message_detail, ensure_ascii=False)
|
||||
return ChatHistoryMessageEntity(
|
||||
@@ -91,22 +104,26 @@ class DBMessageStorageItemAdapter(
|
||||
def from_storage_format(
|
||||
self, model: ChatHistoryMessageEntity
|
||||
) -> MessageStorageItem:
|
||||
"""Convert from storage format."""
|
||||
message_detail = (
|
||||
json.loads(model.message_detail) if model.message_detail else {}
|
||||
json.loads(model.message_detail) # type: ignore
|
||||
if model.message_detail
|
||||
else {}
|
||||
)
|
||||
return MessageStorageItem(
|
||||
conv_uid=model.conv_uid,
|
||||
index=model.index,
|
||||
conv_uid=model.conv_uid, # type: ignore
|
||||
index=model.index, # type: ignore
|
||||
message_detail=message_detail,
|
||||
)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[ChatHistoryMessageEntity],
|
||||
resource_id: MessageIdentifier,
|
||||
resource_id: MessageIdentifier, # type: ignore
|
||||
**kwargs,
|
||||
):
|
||||
session: Session = kwargs.get("session")
|
||||
"""Get query for identifier."""
|
||||
session: Optional[Session] = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
return session.query(ChatHistoryMessageEntity).filter(
|
||||
|
@@ -1,182 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import duckdb
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict
|
||||
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
|
||||
|
||||
from ..base import MemoryStoreType
|
||||
|
||||
default_db_path = os.path.join(PILOT_PATH, "message")
|
||||
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
||||
table_name = "chat_history"
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DuckdbHistoryMemory(BaseChatHistoryMemory):
    """Chat history memory persisted in a local DuckDB database file.

    Every conversation is one row of the ``chat_history`` table, identified by
    ``conv_uid``. The ``messages`` column holds the JSON-encoded list of
    conversation rounds.
    """

    store_type: str = MemoryStoreType.DuckDb.value

    def __init__(self, chat_session_id: str):
        """Open the DuckDB file and make sure the chat history table exists.

        Args:
            chat_session_id (str): Conversation uid this instance reads and writes.
        """
        # NOTE: the attribute keeps its historical (misspelled) name so that any
        # code reading it from outside this class keeps working.
        self.chat_seesion_id = chat_session_id
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_chat_history_tables()

    def __init_chat_history_tables(self):
        """Create the ``chat_history`` table and its id sequence if missing."""
        # Check whether the table already exists.
        result = self.connect.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
        ).fetchall()

        if not result:
            # The table does not exist yet, so create it and its id sequence.
            self.connect.execute(
                "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)"
            )
            self.connect.execute("CREATE SEQUENCE seq_id START 1;")

    def __get_messages_by_conv_uid(self, conv_uid: str):
        """Return the raw ``messages`` column for ``conv_uid``.

        Returns None when there is no row for ``conv_uid``.
        """
        cursor = self.connect.cursor()
        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
        content = cursor.fetchone()
        return content[0] if content else None

    def messages(self) -> List[OnceConversation]:
        """Return the stored conversation rounds of this session.

        Returns:
            List[OnceConversation]: The decoded JSON list, or an empty list when
                nothing has been stored yet.
        """
        context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
        if context:
            conversations: List[OnceConversation] = json.loads(context)
            return conversations
        return []

    def create(self, chat_mode, summary: str, user_name: str) -> None:
        """Insert an empty conversation row for this session.

        Failures are reported on stdout and otherwise ignored (best effort).

        Args:
            chat_mode: Conversation scene mode stored in the ``chat_mode`` column.
            summary (str): Conversation summary.
            user_name (str): Name of the interlocutor.
        """
        try:
            cursor = self.connect.cursor()
            # The column list needs a comma between chat_mode and summary,
            # otherwise the statement is invalid SQL and the insert never happens.
            cursor.execute(
                "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
                [self.chat_seesion_id, chat_mode, summary, user_name, "", ""],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("init create conversation log error!" + str(e))

    def append(self, once_message: OnceConversation) -> None:
        """Append one conversation round to this session and persist it.

        Updates the existing row when there is one, otherwise inserts a new row
        using the metadata carried by ``once_message``.

        Args:
            once_message (OnceConversation): The conversation round to store.
        """
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
        )
        row = cursor.fetchone()
        conversations: List[OnceConversation] = (
            json.loads(row[0]) if row and row[0] else []
        )
        conversations.append(_conversation_to_dict(once_message))
        serialized = json.dumps(conversations, ensure_ascii=False)

        # Decide on row existence, not on whether the messages column is filled:
        # a row created with empty messages must be updated, because inserting
        # again would violate the UNIQUE constraint on conv_uid.
        if row:
            cursor.execute(
                "UPDATE chat_history set messages=? where conv_uid=?",
                [serialized, self.chat_seesion_id],
            )
        else:
            # A conversation may not contain a user message yet.
            latest_user_message = once_message.get_latest_user_message()
            summary = (
                latest_user_message.content if latest_user_message is not None else None
            )
            cursor.execute(
                "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
                [
                    self.chat_seesion_id,
                    once_message.chat_mode,
                    summary,
                    once_message.user_name,
                    once_message.sys_code,
                    serialized,
                ],
            )
        cursor.commit()
        self.connect.commit()

    def update(self, messages: List[OnceConversation]) -> None:
        """Replace the stored conversation rounds of this session.

        Args:
            messages (List[OnceConversation]): JSON-serializable conversation rounds.
        """
        cursor = self.connect.cursor()
        cursor.execute(
            "UPDATE chat_history set messages=? where conv_uid=?",
            [json.dumps(messages, ensure_ascii=False), self.chat_seesion_id],
        )
        cursor.commit()
        self.connect.commit()

    def clear(self) -> None:
        """Delete the row of this session from the chat history table."""
        cursor = self.connect.cursor()
        cursor.execute(
            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
        )
        cursor.commit()
        self.connect.commit()

    def delete(self) -> bool:
        """Delete the row of this session.

        Returns:
            bool: Always True.
        """
        self.clear()
        return True

    def conv_info(self, conv_uid: Optional[str] = None) -> Dict:
        """Return the chat history row of ``conv_uid`` as a dictionary.

        Args:
            conv_uid (Optional[str]): Conversation uid to look up.

        Returns:
            Dict: Column name to value mapping, or an empty dict when no row
                matches.
        """
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT * FROM chat_history where conv_uid=? ",
            [conv_uid],
        )
        # Column names of the query result.
        fields = [field[0] for field in cursor.description]

        # fetchone() yields one row tuple (or None), not an iterable of rows.
        row = cursor.fetchone()
        if row is None:
            return {}
        return dict(zip(fields, row))

    def get_messages(self) -> Optional[List[OnceConversation]]:
        """Return the stored conversation rounds of this session.

        Returns:
            Optional[List[OnceConversation]]: The decoded JSON list, or None when
                the session has no row or its messages column is empty.
        """
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
        )
        context = cursor.fetchone()
        if context and context[0]:
            return json.loads(context[0])
        return None

    @staticmethod
    def conv_list(
        user_name: Optional[str] = None, sys_code: Optional[str] = None
    ) -> List[Dict]:
        """List the 20 most recent conversations, optionally filtered.

        Args:
            user_name (Optional[str]): Only return rows of this user when given.
            sys_code (Optional[str]): Only return rows of this system code when
                given.

        Returns:
            List[Dict]: One column-name-to-value dict per row, newest id first.
                Empty when the database file does not exist.
        """
        if not os.path.isfile(duckdb_path):
            return []

        query = "SELECT * FROM chat_history"
        params = []
        conditions = []
        if user_name:
            conditions.append("user_name = ?")
            params.append(user_name)
        if sys_code:
            conditions.append("sys_code = ?")
            params.append(sys_code)

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        query += " ORDER BY id DESC LIMIT 20"

        # This connection is opened only for this call, so close it afterwards.
        connection = duckdb.connect(duckdb_path)
        try:
            cursor = connection.cursor()
            cursor.execute(query, params)
            fields = [field[0] for field in cursor.description]
            return [dict(zip(fields, row)) for row in cursor.fetchall()]
        finally:
            connection.close()
|
@@ -1,50 +0,0 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.core.interface.message import (
|
||||
OnceConversation,
|
||||
_conversation_from_dict,
|
||||
_conversations_to_dict,
|
||||
)
|
||||
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class FileHistoryMemory(BaseChatHistoryMemory):
    """Chat history memory persisted as one JSON file per chat session.

    The file lives at ``{CFG.message_dir}/{YYYYMMDD}/{chat_session_id}.json``,
    where the date is the day on which the instance is created. All reads and
    writes use UTF-8 so that non-ASCII content written by :meth:`append`
    (``ensure_ascii=False``) can always be read back, whatever the locale.
    """

    store_type: str = MemoryStoreType.File.value

    def __init__(self, chat_session_id: str):
        """Create the history file for the session if it does not exist yet.

        Args:
            chat_session_id: The session id, used as the file name stem.
        """
        now = datetime.datetime.now()
        date_string = now.strftime("%Y%m%d")
        path: str = f"{CFG.message_dir}/{date_string}"
        os.makedirs(path, exist_ok=True)

        self.file_path = Path(path) / f"{chat_session_id}.json"
        if not self.file_path.exists():
            # write_text creates the file, so no separate touch() is needed.
            self.file_path.write_text(json.dumps([]), encoding="UTF-8")

    def messages(self) -> List[OnceConversation]:
        """Read every stored conversation from the session file.

        Returns:
            The conversations in file order, each rebuilt with
            ``_conversation_from_dict``.
        """
        items = json.loads(self.file_path.read_text(encoding="UTF-8"))
        return [_conversation_from_dict(item) for item in items]

    def append(self, once_message: OnceConversation) -> None:
        """Append one conversation and rewrite the whole session file.

        Args:
            once_message: The conversation to add to the end of the history.
        """
        history = self.messages()
        history.append(once_message)
        self.file_path.write_text(
            json.dumps(_conversations_to_dict(history), ensure_ascii=False, indent=4),
            encoding="UTF-8",
        )

    def clear(self) -> None:
        """Reset the session file to an empty JSON list."""
        self.file_path.write_text(json.dumps([]), encoding="UTF-8")
|
@@ -1,27 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.core.interface.message import OnceConversation
|
||||
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
|
||||
from dbgpt.util.custom_data_structure import FixedSizeDict
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class MemHistoryMemory(BaseChatHistoryMemory):
    """In-process chat history memory keyed by chat session id."""

    store_type: str = MemoryStoreType.Memory.value

    # Class attribute: a single map shared by every instance, constructed
    # with the argument 100.
    histroies_map = FixedSizeDict(100)

    def __init__(self, chat_session_id: str):
        """Register the session with an empty history.

        Args:
            chat_session_id: The session id used as the key in the shared map.
                Any history already stored under this key is replaced by an
                empty list.
        """
        self.chat_seesion_id = chat_session_id
        self.histroies_map.update({chat_session_id: []})

    def messages(self) -> List[OnceConversation]:
        """Return the history of this session.

        Returns:
            The stored list of conversations, or an empty list when the
            session key is no longer in the map (e.g. after :meth:`clear`).
        """
        history = self.histroies_map.get(self.chat_seesion_id)
        return [] if history is None else history

    def append(self, once_message: OnceConversation) -> None:
        """Append one conversation to this session's history.

        Args:
            once_message: The conversation to add.
        """
        history = self.histroies_map.get(self.chat_seesion_id)
        if history is None:
            # The key is gone (e.g. clear() was called); start a new history
            # instead of failing on None.
            history = []
            self.histroies_map.update({self.chat_seesion_id: history})
        history.append(once_message)

    def clear(self) -> None:
        """Remove this session's history; a no-op if it is already removed."""
        self.histroies_map.pop(self.chat_seesion_id, None)
|
@@ -1,6 +1,7 @@
|
||||
from dbgpt.storage.metadata._base_dao import BaseDao
|
||||
from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory
|
||||
from dbgpt.storage.metadata.db_manager import (
|
||||
"""Module for handling metadata storage."""
|
||||
from dbgpt.storage.metadata._base_dao import BaseDao # noqa: F401
|
||||
from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory # noqa: F401
|
||||
from dbgpt.storage.metadata.db_manager import ( # noqa: F401
|
||||
BaseModel,
|
||||
DatabaseManager,
|
||||
Model,
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
|
||||
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
@@ -44,6 +44,7 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
self,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
) -> None:
|
||||
"""Create a BaseDao instance."""
|
||||
self._db_manager = db_manager or db
|
||||
|
||||
def get_raw_session(self) -> Session:
|
||||
@@ -52,9 +53,7 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
You should commit or rollback the session manually.
|
||||
We suggest you use :meth:`session` instead.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
user = User(name="Edward Snowden")
|
||||
@@ -63,13 +62,14 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
session.commit()
|
||||
session.close()
|
||||
"""
|
||||
return self._db_manager._session()
|
||||
return self._db_manager._session() # type: ignore
|
||||
|
||||
@contextmanager
|
||||
def session(self, commit: Optional[bool] = True) -> Session:
|
||||
def session(self, commit: Optional[bool] = True) -> Iterator[Session]:
|
||||
"""Provide a transactional scope around a series of operations.
|
||||
|
||||
If raise an exception, the session will be roll back automatically, otherwise it will be committed.
|
||||
If raise an exception, the session will be roll back automatically, otherwise
|
||||
it will be committed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -78,7 +78,8 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
session.query(User).filter(User.name == "Edward Snowden").first()
|
||||
|
||||
Args:
|
||||
commit (Optional[bool], optional): Whether to commit the session. Defaults to True.
|
||||
commit (Optional[bool], optional): Whether to commit the session. Defaults
|
||||
to True.
|
||||
|
||||
Returns:
|
||||
Session: A session object.
|
||||
@@ -147,7 +148,8 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
session.add(entry)
|
||||
req = self.to_request(entry)
|
||||
session.commit()
|
||||
return self.get_one(req)
|
||||
res = self.get_one(req)
|
||||
return res # type: ignore
|
||||
|
||||
def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES:
|
||||
"""Update an entity object.
|
||||
@@ -163,11 +165,14 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
entry = query.first()
|
||||
if entry is None:
|
||||
raise Exception("Invalid request")
|
||||
for key, value in update_request.dict().items():
|
||||
for key, value in update_request.dict().items(): # type: ignore
|
||||
if value is not None:
|
||||
setattr(entry, key, value)
|
||||
session.merge(entry)
|
||||
return self.get_one(self.to_request(entry))
|
||||
res = self.get_one(self.to_request(entry))
|
||||
if not res:
|
||||
raise Exception("Update failed")
|
||||
return res
|
||||
|
||||
def delete(self, query_request: QUERY_SPEC) -> None:
|
||||
"""Delete an entity object.
|
||||
@@ -179,7 +184,8 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
result_list = self._get_entity_list(session, query_request)
|
||||
if len(result_list) != 1:
|
||||
raise ValueError(
|
||||
f"Delete request should return one result, but got {len(result_list)}"
|
||||
f"Delete request should return one result, but got "
|
||||
f"{len(result_list)}"
|
||||
)
|
||||
session.delete(result_list[0])
|
||||
|
||||
@@ -241,11 +247,11 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
query = self._create_query_object(session, query_request)
|
||||
total_count = query.count()
|
||||
items = query.offset((page - 1) * page_size).limit(page_size)
|
||||
items = [self.to_response(item) for item in items]
|
||||
res_items = [self.to_response(item) for item in items]
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginationResult(
|
||||
items=items,
|
||||
items=res_items,
|
||||
total_count=total_count,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
@@ -273,4 +279,4 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
if isinstance(value, (list, tuple, dict, set)):
|
||||
continue
|
||||
query = query.filter(getattr(model_cls, key) == value)
|
||||
return query
|
||||
return query # type: ignore
|
||||
|
@@ -1,19 +1,26 @@
|
||||
"""UnifiedDBManagerFactory is a factory class to create a DatabaseManager instance."""
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
|
||||
from .db_manager import DatabaseManager
|
||||
|
||||
|
||||
class UnifiedDBManagerFactory(BaseComponent):
|
||||
"""UnifiedDBManagerFactory class."""
|
||||
|
||||
name = ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY
|
||||
"""The name of the factory."""
|
||||
|
||||
def __init__(self, system_app: SystemApp, db_manager: DatabaseManager):
|
||||
"""Create a UnifiedDBManagerFactory instance."""
|
||||
super().__init__(system_app)
|
||||
self._db_manager = db_manager
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the factory with the system app."""
|
||||
pass
|
||||
|
||||
def create(self) -> DatabaseManager:
|
||||
"""Create a DatabaseManager instance."""
|
||||
if not self._db_manager:
|
||||
raise RuntimeError("db_manager is not initialized")
|
||||
if not self._db_manager.is_initialized:
|
||||
|
@@ -1,8 +1,9 @@
|
||||
"""The database manager."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import ClassVar, Dict, Generic, Optional, Type, TypeVar, Union
|
||||
from typing import ClassVar, Dict, Generic, Iterator, Optional, Type, TypeVar, Union
|
||||
|
||||
from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm
|
||||
from sqlalchemy.orm import (
|
||||
@@ -21,19 +22,20 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound="BaseModel")
|
||||
|
||||
|
||||
class _QueryObject:
|
||||
"""The query object."""
|
||||
|
||||
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
|
||||
return model_cls.query_class(
|
||||
model_cls, session=model_cls.__db_manager__._session()
|
||||
)
|
||||
# class _QueryObject:
|
||||
# """The query object."""
|
||||
#
|
||||
# def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
|
||||
# return model_cls.query_class(
|
||||
# model_cls, session=model_cls.__db_manager__._session()
|
||||
# )
|
||||
#
|
||||
|
||||
|
||||
class BaseQuery(orm.Query):
|
||||
def paginate_query(
|
||||
self, page: Optional[int] = 1, per_page: Optional[int] = 20
|
||||
) -> PaginationResult:
|
||||
"""Base query class."""
|
||||
|
||||
def paginate_query(self, page: int = 1, per_page: int = 20) -> PaginationResult:
|
||||
"""Paginate the query.
|
||||
|
||||
Example:
|
||||
@@ -56,10 +58,10 @@ class BaseQuery(orm.Query):
|
||||
)
|
||||
print(pagination)
|
||||
|
||||
|
||||
Args:
|
||||
page (Optional[int], optional): The page number. Defaults to 1.
|
||||
per_page (Optional[int], optional): The number of items per page. Defaults to 20.
|
||||
per_page (Optional[int], optional): The number of items per page. Defaults
|
||||
to 20.
|
||||
Returns:
|
||||
PaginationResult: The pagination result.
|
||||
"""
|
||||
@@ -100,7 +102,6 @@ class DatabaseManager:
|
||||
"""The database manager.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from urllib.parse import quote_plus as urlquote, quote
|
||||
@@ -161,6 +162,7 @@ class DatabaseManager:
|
||||
Query = BaseQuery
|
||||
|
||||
def __init__(self):
|
||||
"""Create a DatabaseManager."""
|
||||
self._db_url = None
|
||||
self._base: DeclarativeMeta = self._make_declarative_base(_Model)
|
||||
self._engine: Optional[Engine] = None
|
||||
@@ -169,12 +171,12 @@ class DatabaseManager:
|
||||
@property
|
||||
def Model(self) -> _Model:
|
||||
"""Get the declarative base."""
|
||||
return self._base
|
||||
return self._base # type: ignore
|
||||
|
||||
@property
|
||||
def metadata(self) -> MetaData:
|
||||
"""Get the metadata."""
|
||||
return self.Model.metadata
|
||||
return self.Model.metadata # type: ignore
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
@@ -183,11 +185,11 @@ class DatabaseManager:
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Whether the database manager is initialized.""" ""
|
||||
"""Whether the database manager is initialized."""
|
||||
return self._engine is not None and self._session is not None
|
||||
|
||||
@contextmanager
|
||||
def session(self, commit: Optional[bool] = True) -> Session:
|
||||
def session(self, commit: Optional[bool] = True) -> Iterator[Session]:
|
||||
"""Get the session with context manager.
|
||||
|
||||
This context manager handles the lifecycle of a SQLAlchemy session.
|
||||
@@ -199,7 +201,6 @@ class DatabaseManager:
|
||||
read and write operations.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# For write operations (insert, update, delete):
|
||||
@@ -211,7 +212,8 @@ class DatabaseManager:
|
||||
# For read-only operations:
|
||||
with db.session(commit=False) as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
# session.commit() is NOT called, as it's unnecessary for read operations
|
||||
# session.commit() is NOT called, as it's unnecessary for read
|
||||
# operations
|
||||
|
||||
Args:
|
||||
commit (Optional[bool], optional): Whether to commit the session.
|
||||
@@ -237,16 +239,15 @@ class DatabaseManager:
|
||||
with the ORM object are complete.
|
||||
c. Re-bind the instance to a new session if further interaction
|
||||
is required after the session is closed.
|
||||
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("The database manager is not initialized.")
|
||||
session = self._session()
|
||||
session = self._session() # type: ignore
|
||||
try:
|
||||
yield session
|
||||
if commit:
|
||||
session.commit()
|
||||
except:
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
@@ -266,10 +267,10 @@ class DatabaseManager:
|
||||
if not isinstance(model, DeclarativeMeta):
|
||||
model = declarative_base(cls=model, name="Model")
|
||||
if not getattr(model, "query_class", None):
|
||||
model.query_class = self.Query
|
||||
model.query_class = self.Query # type: ignore
|
||||
# model.query = _QueryObject()
|
||||
model.__db_manager__ = self
|
||||
return model
|
||||
model.__db_manager__ = self # type: ignore
|
||||
return model # type: ignore
|
||||
|
||||
def init_db(
|
||||
self,
|
||||
@@ -284,11 +285,14 @@ class DatabaseManager:
|
||||
|
||||
Args:
|
||||
db_url (Union[str, URL]): The database url.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
|
||||
None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
override_query_class (Optional[bool], optional): Whether to override the
|
||||
query class. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults
|
||||
to None.
|
||||
"""
|
||||
if session_options is None:
|
||||
session_options = {}
|
||||
@@ -309,8 +313,8 @@ class DatabaseManager:
|
||||
session_options.setdefault("query_cls", self.Query)
|
||||
session_factory = sessionmaker(bind=self._engine, **session_options)
|
||||
# self._session = scoped_session(session_factory)
|
||||
self._session = session_factory
|
||||
self._base.metadata.bind = self._engine
|
||||
self._session = session_factory # type: ignore
|
||||
self._base.metadata.bind = self._engine # type: ignore
|
||||
|
||||
def init_default_db(
|
||||
self,
|
||||
@@ -333,24 +337,28 @@ class DatabaseManager:
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
"""
|
||||
if not engine_args:
|
||||
engine_args = {}
|
||||
# Pool class
|
||||
engine_args["poolclass"] = QueuePool
|
||||
# The number of connections to keep open inside the connection pool.
|
||||
engine_args["pool_size"] = 10
|
||||
# The maximum overflow size of the pool when the number of connections be used in the pool is exceeded(
|
||||
# pool_size).
|
||||
engine_args["max_overflow"] = 20
|
||||
# The number of seconds to wait before giving up on getting a connection from the pool.
|
||||
engine_args["pool_timeout"] = 30
|
||||
# Recycle the connection if it has been idle for this many seconds.
|
||||
engine_args["pool_recycle"] = 3600
|
||||
# Enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout.
|
||||
engine_args["pool_pre_ping"] = True
|
||||
engine_args = {
|
||||
# Pool class
|
||||
"poolclass": QueuePool,
|
||||
# The number of connections to keep open inside the connection pool.
|
||||
"pool_size": 10,
|
||||
# The maximum overflow size of the pool when the number of connections
|
||||
# be used in the pool is exceeded(pool_size).
|
||||
"max_overflow": 20,
|
||||
# The number of seconds to wait before giving up on getting a connection
|
||||
# from the pool.
|
||||
"pool_timeout": 30,
|
||||
# Recycle the connection if it has been idle for this many seconds.
|
||||
"pool_recycle": 3600,
|
||||
# Enable the connection pool “pre-ping” feature that tests connections
|
||||
# for liveness upon each checkout.
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
|
||||
self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)
|
||||
|
||||
def create_all(self):
|
||||
"""Create all tables."""
|
||||
self.Model.metadata.create_all(self._engine)
|
||||
|
||||
@staticmethod
|
||||
@@ -364,9 +372,7 @@ class DatabaseManager:
|
||||
"""Build the database manager from the db_url_or_db.
|
||||
|
||||
Examples:
|
||||
|
||||
Build from the database url.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
@@ -389,16 +395,19 @@ class DatabaseManager:
|
||||
print(User.query.filter(User.name == "test").all())
|
||||
|
||||
Args:
|
||||
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the
|
||||
database manager.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
|
||||
None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
override_query_class (Optional[bool], optional): Whether to override the
|
||||
query class. Defaults to False.
|
||||
|
||||
Returns:
|
||||
DatabaseManager: The database manager.
|
||||
"""
|
||||
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
|
||||
if isinstance(db_url_or_db, (str, URL)):
|
||||
db_manager = DatabaseManager()
|
||||
db_manager.init_db(
|
||||
db_url_or_db, engine_args, base, query_class, override_query_class
|
||||
@@ -408,7 +417,8 @@ class DatabaseManager:
|
||||
return db_url_or_db
|
||||
else:
|
||||
raise ValueError(
|
||||
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
|
||||
f"db_url_or_db should be either url or a DatabaseManager, got "
|
||||
f"{type(db_url_or_db)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -422,7 +432,6 @@ Examples:
|
||||
>>> with db.session() as session:
|
||||
... session.query(...)
|
||||
...
|
||||
|
||||
>>> from dbgpt.storage.metadata import db, Model
|
||||
>>> from urllib.parse import quote_plus as urlquote, quote
|
||||
>>> db_name = "dbgpt"
|
||||
@@ -430,7 +439,10 @@ Examples:
|
||||
>>> db_port = 3306
|
||||
>>> user = "root"
|
||||
>>> password = "123456"
|
||||
>>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
|
||||
>>> url = (
|
||||
... f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}"
|
||||
... f":{str(db_port)}/{db_name}"
|
||||
... )
|
||||
>>> engine_args = {
|
||||
... "pool_size": 10,
|
||||
... "max_overflow": 20,
|
||||
@@ -460,18 +472,21 @@ class BaseCRUDMixin(Generic[T]):
|
||||
@classmethod
|
||||
def db(cls) -> DatabaseManager:
|
||||
"""Get the database manager."""
|
||||
return cls.__db_manager__
|
||||
return cls.__db_manager__ # type: ignore
|
||||
|
||||
|
||||
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
|
||||
"""The base model class that includes CRUD convenience methods."""
|
||||
|
||||
__abstract__ = True
|
||||
"""Whether the model is abstract."""
|
||||
|
||||
|
||||
def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
|
||||
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
|
||||
"""Create a model."""
|
||||
|
||||
class CRUDMixin(BaseCRUDMixin[T], Generic[T]): # type: ignore
|
||||
"""Mixin that adds convenience methods for CRUD."""
|
||||
|
||||
_db_manager: DatabaseManager = db_manager
|
||||
|
||||
@@ -485,7 +500,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
"""Get the database manager."""
|
||||
return cls._db_manager
|
||||
|
||||
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
|
||||
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): # type: ignore
|
||||
"""Base model class that includes CRUD convenience methods."""
|
||||
|
||||
__abstract__ = True
|
||||
@@ -493,7 +508,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
return _NewModel
|
||||
|
||||
|
||||
Model = create_model(db)
|
||||
Model: Type = create_model(db)
|
||||
|
||||
|
||||
def initialize_db(
|
||||
@@ -511,8 +526,10 @@ def initialize_db(
|
||||
db_name (str): The database name.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
try_to_create_db (Optional[bool], optional): Whether to try to create the
|
||||
database. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to
|
||||
None.
|
||||
Returns:
|
||||
DatabaseManager: The database manager.
|
||||
"""
|
||||
|
@@ -1,5 +1,6 @@
|
||||
"""Database storage implementation using SQLAlchemy."""
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
from typing import Dict, Iterator, List, Optional, Type, Union
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.orm import DeclarativeMeta, Session
|
||||
@@ -17,8 +18,8 @@ from .db_manager import BaseModel, BaseQuery, DatabaseManager
|
||||
|
||||
|
||||
def _copy_public_properties(src: BaseModel, dest: BaseModel):
|
||||
"""Simple copy public properties from src to dest"""
|
||||
for column in src.__table__.columns:
|
||||
"""Copy public properties from src to dest."""
|
||||
for column in src.__table__.columns: # type: ignore
|
||||
if column.name != "id":
|
||||
value = getattr(src, column.name)
|
||||
if value is not None:
|
||||
@@ -26,6 +27,8 @@ def _copy_public_properties(src: BaseModel, dest: BaseModel):
|
||||
|
||||
|
||||
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
"""Database storage implementation using SQLAlchemy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_url_or_db: Union[str, URL, DatabaseManager],
|
||||
@@ -36,6 +39,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
query_class=BaseQuery,
|
||||
):
|
||||
"""Create a SQLAlchemyStorage instance."""
|
||||
super().__init__(serializer=serializer, adapter=adapter)
|
||||
self.db_manager = DatabaseManager.build_from(
|
||||
db_url_or_db, engine_args, base, query_class
|
||||
@@ -43,16 +47,19 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
self._model_class = model_class
|
||||
|
||||
@contextmanager
|
||||
def session(self) -> Session:
|
||||
def session(self) -> Iterator[Session]:
|
||||
"""Return a session."""
|
||||
with self.db_manager.session() as session:
|
||||
yield session
|
||||
|
||||
def save(self, data: T) -> None:
|
||||
"""Save data to the storage."""
|
||||
with self.session() as session:
|
||||
model_instance = self.adapter.to_storage_format(data)
|
||||
session.add(model_instance)
|
||||
|
||||
def update(self, data: T) -> None:
|
||||
"""Update data in the storage."""
|
||||
with self.session() as session:
|
||||
query = self.adapter.get_query_for_identifier(
|
||||
self._model_class, data.identifier, session=session
|
||||
@@ -66,6 +73,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
return
|
||||
|
||||
def save_or_update(self, data: T) -> None:
|
||||
"""Save or update data in the storage."""
|
||||
with self.session() as session:
|
||||
query = self.adapter.get_query_for_identifier(
|
||||
self._model_class, data.identifier, session=session
|
||||
@@ -79,6 +87,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
self.save(data)
|
||||
|
||||
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
|
||||
"""Load data by identifier from the storage."""
|
||||
with self.session() as session:
|
||||
query = self.adapter.get_query_for_identifier(
|
||||
self._model_class, resource_id, session=session
|
||||
@@ -89,6 +98,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
return None
|
||||
|
||||
def delete(self, resource_id: ResourceIdentifier) -> None:
|
||||
"""Delete data by identifier from the storage."""
|
||||
with self.session() as session:
|
||||
query = self.adapter.get_query_for_identifier(
|
||||
self._model_class, resource_id, session=session
|
||||
|
@@ -1,14 +1,21 @@
|
||||
"""Database information class and database type enumeration."""
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DbInfo:
    """Database information class.

    Holds the name of a database type and whether it is a file database.
    """

    def __init__(self, name: str, is_file_db: bool = False):
        """Create a new instance of DbInfo.

        Args:
            name: The name of the database type, e.g. "mysql" or "duckdb".
            is_file_db: Whether the database is a file database. Defaults to
                False.
        """
        self.name = name
        self.is_file_db = is_file_db
|
||||
|
||||
|
||||
class DBType(Enum):
|
||||
"""Database type enumeration."""
|
||||
|
||||
Mysql = DbInfo("mysql")
|
||||
OCeanBase = DbInfo("oceanbase")
|
||||
DuckDb = DbInfo("duckdb", True)
|
||||
@@ -22,14 +29,24 @@ class DBType(Enum):
|
||||
Doris = DbInfo("doris")
|
||||
Hive = DbInfo("hive")
|
||||
|
||||
def value(self):
|
||||
def value(self) -> str:
|
||||
"""Return the name of the database type."""
|
||||
return self._value_.name
|
||||
|
||||
def is_file_db(self):
|
||||
def is_file_db(self) -> bool:
|
||||
"""Return whether the database is a file database."""
|
||||
return self._value_.is_file_db
|
||||
|
||||
@staticmethod
|
||||
def of_db_type(db_type: str):
|
||||
def of_db_type(db_type: str) -> Optional["DBType"]:
|
||||
"""Return the database type of the given name.
|
||||
|
||||
Args:
|
||||
db_type (str): The name of the database type.
|
||||
|
||||
Returns:
|
||||
Optional[DBType]: The database type of the given name.
|
||||
"""
|
||||
for item in DBType:
|
||||
if item.value() == db_type:
|
||||
return item
|
||||
@@ -37,7 +54,7 @@ class DBType(Enum):
|
||||
|
||||
@staticmethod
|
||||
def parse_file_db_name_from_path(db_type: str, local_db_path: str):
|
||||
"""Parse out the database name of the embedded database from the file path"""
|
||||
"""Parse out the database name of the embedded database from the file path."""
|
||||
base_name = os.path.basename(local_db_path)
|
||||
db_name = os.path.splitext(base_name)[0]
|
||||
if "." in db_name:
|
||||
|
@@ -1,3 +1,4 @@
|
||||
"""Vector Store Module."""
|
||||
from typing import Any
|
||||
|
||||
|
||||
|
@@ -1,12 +1,13 @@
|
||||
"""Vector store base class."""
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -15,6 +16,11 @@ logger = logging.getLogger(__name__)
|
||||
class VectorStoreConfig(BaseModel):
|
||||
"""Vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
name: str = Field(
|
||||
default="dbgpt_collection",
|
||||
description="The name of vector store, if not set, will use the default name.",
|
||||
@@ -28,7 +34,7 @@ class VectorStoreConfig(BaseModel):
|
||||
description="The password of vector store, if not set, will use the default "
|
||||
"password.",
|
||||
)
|
||||
embedding_fn: Optional[Any] = Field(
|
||||
embedding_fn: Optional[Embeddings] = Field(
|
||||
default=None,
|
||||
description="The embedding function of vector store, if not set, will use the "
|
||||
"default embedding function.",
|
||||
@@ -47,27 +53,31 @@ class VectorStoreConfig(BaseModel):
|
||||
|
||||
|
||||
class VectorStoreBase(ABC):
|
||||
"""base class for vector store database"""
|
||||
"""Vector store base class."""
|
||||
|
||||
@abstractmethod
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""load document in vector database.
|
||||
"""Load document in vector database.
|
||||
|
||||
Args:
|
||||
- chunks: document chunks.
|
||||
chunks(List[Chunk]): document chunks.
|
||||
|
||||
Return:
|
||||
- ids: chunks ids.
|
||||
List[str]: chunk ids.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_document_with_limit(
|
||||
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
|
||||
) -> List[str]:
|
||||
"""load document in vector database with limit.
|
||||
"""Load document in vector database with specified limit.
|
||||
|
||||
Args:
|
||||
chunks: document chunks.
|
||||
max_chunks_once_load: Max number of chunks to load at once.
|
||||
max_threads: Max number of threads to use.
|
||||
chunks(List[Chunk]): Document chunks.
|
||||
max_chunks_once_load(int): Max number of chunks to load at once.
|
||||
max_threads(int): Max number of threads to use.
|
||||
|
||||
Return:
|
||||
List[str]: Chunk ids.
|
||||
"""
|
||||
# Group the chunks into chunks of size max_chunks
|
||||
chunk_groups = [
|
||||
@@ -96,13 +106,15 @@ class VectorStoreBase(ABC):
|
||||
return ids
|
||||
|
||||
@abstractmethod
|
||||
def similar_search(self, text, topk) -> List[Chunk]:
|
||||
"""similar search in vector database.
|
||||
def similar_search(self, text: str, topk: int) -> List[Chunk]:
|
||||
"""Similar search in vector database.
|
||||
|
||||
Args:
|
||||
- text: query text
|
||||
- topk: topk
|
||||
text(str): The query text.
|
||||
topk(int): The number of similar documents to return.
|
||||
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
List[Chunk]: The similar documents.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -110,38 +122,43 @@ class VectorStoreBase(ABC):
|
||||
def similar_search_with_scores(
|
||||
self, text, topk, score_threshold: float
|
||||
) -> List[Chunk]:
|
||||
"""similar search in vector database with scores.
|
||||
"""Similar search with scores in vector database.
|
||||
|
||||
Args:
|
||||
- text: query text
|
||||
- topk: topk
|
||||
- score_threshold: score_threshold: Optional, a floating point value between 0 to 1
|
||||
text(str): The query text.
|
||||
topk(int): The number of similar documents to return.
|
||||
score_threshold(float): Optional, a floating point value
|
||||
between 0 to 1
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
List[Chunk]: The similar documents.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""is vector store name exist."""
|
||||
"""Whether vector name exists."""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids):
|
||||
"""delete vector by ids.
|
||||
def delete_by_ids(self, ids: str):
|
||||
"""Delete vectors by ids.
|
||||
|
||||
Args:
|
||||
- ids: vector ids
|
||||
ids(str): The ids of vectors to delete, separated by comma.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_vector_name(self, vector_name):
|
||||
"""delete vector name.
|
||||
def delete_vector_name(self, vector_name: str):
|
||||
"""Delete vector by name.
|
||||
|
||||
Args:
|
||||
- vector_name: vector store name
|
||||
vector_name(str): The name of vector to delete.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _normalization_vectors(self, vectors):
|
||||
"""normalization vectors to scale[0,1]"""
|
||||
"""Return L2-normalization vectors to scale[0,1].
|
||||
|
||||
Normalization vectors to scale[0,1].
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
norm = np.linalg.norm(vectors)
|
||||
|
@@ -1,14 +1,18 @@
|
||||
"""Chroma vector store."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
from chromadb import PersistentClient
|
||||
from chromadb.config import Settings
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
|
||||
# TODO: Resolve the circular dependency between rag and storage
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
from .base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,20 +20,28 @@ logger = logging.getLogger(__name__)
|
||||
class ChromaVectorConfig(VectorStoreConfig):
|
||||
"""Chroma vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
persist_path: str = Field(
|
||||
default=os.getenv("CHROMA_PERSIST_PATH", None),
|
||||
description="The password of vector store, if not set, will use the default password.",
|
||||
description="The password of vector store, if not set, will use the default "
|
||||
"password.",
|
||||
)
|
||||
collection_metadata: dict = Field(
|
||||
default=None,
|
||||
description="the index metadata of vector store, if not set, will use the default metadata.",
|
||||
description="the index metadata of vector store, if not set, will use the "
|
||||
"default metadata.",
|
||||
)
|
||||
|
||||
|
||||
class ChromaStore(VectorStoreBase):
|
||||
"""chroma database"""
|
||||
"""Chroma vector store."""
|
||||
|
||||
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
|
||||
"""Create a ChromaStore instance."""
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
chroma_vector_config = vector_store_config.dict()
|
||||
@@ -59,6 +71,7 @@ class ChromaStore(VectorStoreBase):
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
|
||||
"""Search similar documents."""
|
||||
logger.info("ChromaStore similar search")
|
||||
lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
|
||||
return [
|
||||
@@ -67,14 +80,16 @@ class ChromaStore(VectorStoreBase):
|
||||
]
|
||||
|
||||
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
|
||||
"""
|
||||
"""Search similar documents with scores.
|
||||
|
||||
Chroma similar_search_with_score.
|
||||
Return docs and relevance scores in the range [0, 1].
|
||||
Args:
|
||||
text(str): query text
|
||||
topk(int): the number of docs to return.
|
||||
score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
|
||||
score_threshold(float): score_threshold: Optional, a floating point value
|
||||
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
||||
dissimilar, 1 is most similar.
|
||||
"""
|
||||
logger.info("ChromaStore similar search with scores")
|
||||
docs_and_scores = (
|
||||
@@ -87,8 +102,8 @@ class ChromaStore(VectorStoreBase):
|
||||
for doc, score in docs_and_scores
|
||||
]
|
||||
|
||||
def vector_name_exists(self):
|
||||
"""is vector store name exist."""
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""Whether vector name exists."""
|
||||
logger.info(f"Check persist_dir: {self.persist_dir}")
|
||||
if not os.path.exists(self.persist_dir):
|
||||
return False
|
||||
@@ -98,6 +113,7 @@ class ChromaStore(VectorStoreBase):
|
||||
return len(files) > 0
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load document to vector store."""
|
||||
logger.info("ChromaStore load document")
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
metadatas = [chunk.metadata for chunk in chunks]
|
||||
@@ -105,14 +121,16 @@ class ChromaStore(VectorStoreBase):
|
||||
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||
return ids
|
||||
|
||||
def delete_vector_name(self, vector_name):
|
||||
def delete_vector_name(self, vector_name: str):
|
||||
"""Delete vector name."""
|
||||
logger.info(f"chroma vector_name:{vector_name} begin delete...")
|
||||
self.vector_store_client.delete_collection()
|
||||
self._clean_persist_folder()
|
||||
return True
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
logger.info(f"begin delete chroma ids...")
|
||||
"""Delete vector by ids."""
|
||||
logger.info(f"begin delete chroma ids: {ids}")
|
||||
ids = ids.split(",")
|
||||
if len(ids) > 0:
|
||||
collection = self.vector_store_client._collection
|
||||
|
@@ -1,19 +1,26 @@
|
||||
"""Connector for vector store."""
|
||||
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage import vector_store
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
connector = {}
|
||||
connector: Dict[str, Type] = {}
|
||||
|
||||
|
||||
class VectorStoreConnector:
|
||||
"""The connector for vector store.
|
||||
|
||||
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
|
||||
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
||||
2.similar_search: similarity search from vector_store
|
||||
3.similar_search_with_scores: similarity search with similarity score from vector_store
|
||||
VectorStoreConnector, can connect different vector db provided load document api_v1
|
||||
and similar search api_v1.
|
||||
|
||||
1.load_document:knowledge document source into vector store.(Chroma, Milvus,
|
||||
Weaviate).
|
||||
2.similar_search: similarity search from vector_store.
|
||||
3.similar_search_with_scores: similarity search with similarity score from
|
||||
vector_store
|
||||
|
||||
code example:
|
||||
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
@@ -23,9 +30,12 @@ class VectorStoreConnector:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, vector_store_type: str, vector_store_config: VectorStoreConfig = None
|
||||
self,
|
||||
vector_store_type: str,
|
||||
vector_store_config: Optional[VectorStoreConfig] = None,
|
||||
) -> None:
|
||||
"""initialize vector store connector.
|
||||
"""Create a VectorStoreConnector instance.
|
||||
|
||||
Args:
|
||||
- vector_store_type: vector store type Milvus, Chroma, Weaviate
|
||||
- vector_store_config: vector store config params.
|
||||
@@ -34,7 +44,7 @@ class VectorStoreConnector:
|
||||
self._register()
|
||||
|
||||
if self._match(vector_store_type):
|
||||
self.connector_class = connector.get(vector_store_type)
|
||||
self.connector_class = connector[vector_store_type]
|
||||
else:
|
||||
raise Exception(f"Vector Store Type Not support. {0}", vector_store_type)
|
||||
|
||||
@@ -44,11 +54,11 @@ class VectorStoreConnector:
|
||||
@classmethod
|
||||
def from_default(
|
||||
cls,
|
||||
vector_store_type: str = None,
|
||||
vector_store_type: Optional[str] = None,
|
||||
embedding_fn: Optional[Any] = None,
|
||||
vector_store_config: Optional[VectorStoreConfig] = None,
|
||||
) -> "VectorStoreConnector":
|
||||
"""initialize default vector store connector."""
|
||||
"""Initialize default vector store connector."""
|
||||
vector_store_type = vector_store_type or os.getenv(
|
||||
"VECTOR_STORE_TYPE", "Chroma"
|
||||
)
|
||||
@@ -56,22 +66,33 @@ class VectorStoreConnector:
|
||||
|
||||
vector_store_config = vector_store_config or ChromaVectorConfig()
|
||||
vector_store_config.embedding_fn = embedding_fn
|
||||
return cls(vector_store_type, vector_store_config)
|
||||
real_vector_store_type = cast(str, vector_store_type)
|
||||
return cls(real_vector_store_type, vector_store_config)
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""load document in vector database.
|
||||
"""Load document in vector database.
|
||||
|
||||
Args:
|
||||
- chunks: document chunks.
|
||||
Return chunk ids.
|
||||
"""
|
||||
max_chunks_once_load = (
|
||||
self._vector_store_config.max_chunks_once_load
|
||||
if self._vector_store_config
|
||||
else 10
|
||||
)
|
||||
max_threads = (
|
||||
self._vector_store_config.max_threads if self._vector_store_config else 1
|
||||
)
|
||||
return self.client.load_document_with_limit(
|
||||
chunks,
|
||||
self._vector_store_config.max_chunks_once_load,
|
||||
self._vector_store_config.max_threads,
|
||||
max_chunks_once_load,
|
||||
max_threads,
|
||||
)
|
||||
|
||||
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
|
||||
"""similar search in vector database.
|
||||
"""Similar search in vector database.
|
||||
|
||||
Args:
|
||||
- doc: query text
|
||||
- topk: topk
|
||||
@@ -83,14 +104,17 @@ class VectorStoreConnector:
|
||||
def similar_search_with_scores(
|
||||
self, doc: str, topk: int, score_threshold: float
|
||||
) -> List[Chunk]:
|
||||
"""
|
||||
"""Similar search with scores in vector database.
|
||||
|
||||
similar_search_with_score in vector database..
|
||||
Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
Args:
|
||||
- doc(str): query text
|
||||
- topk(int): return docs nums. Defaults to 4.
|
||||
- score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
|
||||
doc(str): query text
|
||||
topk(int): the number of docs to return.
|
||||
score_threshold(float): score_threshold: Optional, a floating point value
|
||||
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
||||
dissimilar, 1 is most similar.
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
"""
|
||||
@@ -98,32 +122,33 @@ class VectorStoreConnector:
|
||||
|
||||
@property
|
||||
def vector_store_config(self) -> VectorStoreConfig:
|
||||
"""vector store config."""
|
||||
"""Return the vector store config."""
|
||||
if not self._vector_store_config:
|
||||
raise ValueError("vector store config not set.")
|
||||
return self._vector_store_config
|
||||
|
||||
def vector_name_exists(self):
|
||||
"""is vector store name exist."""
|
||||
"""Whether vector name exists."""
|
||||
return self.client.vector_name_exists()
|
||||
|
||||
def delete_vector_name(self, vector_name):
|
||||
"""vector store delete
|
||||
def delete_vector_name(self, vector_name: str):
|
||||
"""Delete vector name.
|
||||
|
||||
Args:
|
||||
- vector_name: vector store name
|
||||
"""
|
||||
return self.client.delete_vector_name(vector_name)
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
"""vector store delete by ids.
|
||||
"""Delete vector by ids.
|
||||
|
||||
Args:
|
||||
- ids: vector ids
|
||||
"""
|
||||
return self.client.delete_by_ids(ids=ids)
|
||||
|
||||
def _match(self, vector_store_type) -> bool:
|
||||
if connector.get(vector_store_type):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return bool(connector.get(vector_store_type))
|
||||
|
||||
def _register(self):
|
||||
for cls in vector_store.__all__:
|
||||
|
@@ -1,13 +1,14 @@
|
||||
"""Milvus vector store."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt.rag.chunk import Chunk, Document
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt.util import string_utils
|
||||
|
||||
@@ -17,6 +18,11 @@ logger = logging.getLogger(__name__)
|
||||
class MilvusVectorConfig(VectorStoreConfig):
|
||||
"""Milvus vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
uri: str = Field(
|
||||
default="localhost",
|
||||
description="The uri of milvus store, if not set, will use the default uri.",
|
||||
@@ -28,7 +34,8 @@ class MilvusVectorConfig(VectorStoreConfig):
|
||||
|
||||
alias: str = Field(
|
||||
default="default",
|
||||
description="The alias of milvus store, if not set, will use the default alias.",
|
||||
description="The alias of milvus store, if not set, will use the default "
|
||||
"alias.",
|
||||
)
|
||||
user: str = Field(
|
||||
default=None,
|
||||
@@ -36,35 +43,42 @@ class MilvusVectorConfig(VectorStoreConfig):
|
||||
)
|
||||
password: str = Field(
|
||||
default=None,
|
||||
description="The password of milvus store, if not set, will use the default password.",
|
||||
description="The password of milvus store, if not set, will use the default "
|
||||
"password.",
|
||||
)
|
||||
primary_field: str = Field(
|
||||
default="pk_id",
|
||||
description="The primary field of milvus store, if not set, will use the default primary field.",
|
||||
description="The primary field of milvus store, if not set, will use the "
|
||||
"default primary field.",
|
||||
)
|
||||
text_field: str = Field(
|
||||
default="content",
|
||||
description="The text field of milvus store, if not set, will use the default text field.",
|
||||
description="The text field of milvus store, if not set, will use the default "
|
||||
"text field.",
|
||||
)
|
||||
embedding_field: str = Field(
|
||||
default="vector",
|
||||
description="The embedding field of milvus store, if not set, will use the default embedding field.",
|
||||
description="The embedding field of milvus store, if not set, will use the "
|
||||
"default embedding field.",
|
||||
)
|
||||
metadata_field: str = Field(
|
||||
default="metadata",
|
||||
description="The metadata field of milvus store, if not set, will use the default metadata field.",
|
||||
description="The metadata field of milvus store, if not set, will use the "
|
||||
"default metadata field.",
|
||||
)
|
||||
secure: str = Field(
|
||||
default="",
|
||||
description="The secure of milvus store, if not set, will use the default secure.",
|
||||
description="The secure of milvus store, if not set, will use the default "
|
||||
"secure.",
|
||||
)
|
||||
|
||||
|
||||
class MilvusStore(VectorStoreBase):
|
||||
"""Milvus database"""
|
||||
"""Milvus vector store."""
|
||||
|
||||
def __init__(self, vector_store_config: MilvusVectorConfig) -> None:
|
||||
"""MilvusStore init.
|
||||
"""Create a MilvusStore instance.
|
||||
|
||||
Args:
|
||||
vector_store_config (MilvusVectorConfig): MilvusStore config.
|
||||
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
|
||||
@@ -93,8 +107,11 @@ class MilvusStore(VectorStoreBase):
|
||||
hex_str = bytes_str.hex()
|
||||
self.collection_name = hex_str
|
||||
|
||||
self.embedding = vector_store_config.embedding_fn
|
||||
self.fields = []
|
||||
if not vector_store_config.embedding_fn:
|
||||
raise ValueError("embedding is required for MilvusStore")
|
||||
|
||||
self.embedding: Embeddings = vector_store_config.embedding_fn
|
||||
self.fields: List = []
|
||||
self.alias = milvus_vector_config.get("alias") or "default"
|
||||
|
||||
# use HNSW by default.
|
||||
@@ -124,7 +141,8 @@ class MilvusStore(VectorStoreBase):
|
||||
|
||||
if (self.username is None) != (self.password is None):
|
||||
raise ValueError(
|
||||
"Both username and password must be set to use authentication for Milvus"
|
||||
"Both username and password must be set to use authentication for "
|
||||
"Milvus"
|
||||
)
|
||||
if self.username:
|
||||
connect_kwargs["user"] = self.username
|
||||
@@ -139,7 +157,10 @@ class MilvusStore(VectorStoreBase):
|
||||
)
|
||||
|
||||
def init_schema_and_load(self, vector_name, documents) -> List[str]:
|
||||
"""Create a Milvus collection, indexes it with HNSW, load document.
|
||||
"""Create a Milvus collection.
|
||||
|
||||
Create a Milvus collection, indexes it with HNSW, load document.
|
||||
|
||||
Args:
|
||||
vector_name (str): your collection name.
|
||||
documents (List[str]): Text to insert.
|
||||
@@ -155,7 +176,7 @@ class MilvusStore(VectorStoreBase):
|
||||
connections,
|
||||
utility,
|
||||
)
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
from pymilvus.orm.types import infer_dtype_bydata # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
@@ -240,10 +261,10 @@ class MilvusStore(VectorStoreBase):
|
||||
partition_name: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""add text data into Milvus."""
|
||||
"""Add text data into Milvus."""
|
||||
insert_dict: Any = {self.text_field: list(texts)}
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np # noqa: F401
|
||||
|
||||
text_vector = self.embedding.embed_documents(list(texts))
|
||||
insert_dict[self.vector_field] = text_vector
|
||||
@@ -268,7 +289,7 @@ class MilvusStore(VectorStoreBase):
|
||||
return res.primary_keys
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""load document in vector database."""
|
||||
"""Load document in vector database."""
|
||||
batch_size = 500
|
||||
batched_list = [
|
||||
chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)
|
||||
@@ -280,6 +301,7 @@ class MilvusStore(VectorStoreBase):
|
||||
return doc_ids
|
||||
|
||||
def similar_search(self, text, topk) -> List[Chunk]:
|
||||
"""Perform a search on a query string and return results."""
|
||||
from pymilvus import Collection, DataType
|
||||
|
||||
"""similar_search in vector database."""
|
||||
@@ -409,12 +431,14 @@ class MilvusStore(VectorStoreBase):
|
||||
return ret[0], ret
|
||||
|
||||
def vector_name_exists(self):
|
||||
"""Whether vector name exists."""
|
||||
from pymilvus import utility
|
||||
|
||||
"""is vector store name exist."""
|
||||
return utility.has_collection(self.collection_name)
|
||||
|
||||
def delete_vector_name(self, vector_name):
|
||||
def delete_vector_name(self, vector_name: str):
|
||||
"""Delete vector name."""
|
||||
from pymilvus import utility
|
||||
|
||||
"""milvus delete collection name"""
|
||||
@@ -423,11 +447,12 @@ class MilvusStore(VectorStoreBase):
|
||||
return True
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
"""Delete vector by ids."""
|
||||
from pymilvus import Collection
|
||||
|
||||
self.col = Collection(self.collection_name)
|
||||
"""milvus delete vectors by ids"""
|
||||
logger.info(f"begin delete milvus ids...")
|
||||
# milvus delete vectors by ids
|
||||
logger.info(f"begin delete milvus ids: {ids}")
|
||||
delete_ids = ids.split(",")
|
||||
doc_ids = [int(doc_id) for doc_id in delete_ids]
|
||||
delet_expr = f"{self.primary_field} in {doc_ids}"
|
||||
|
@@ -1,9 +1,9 @@
|
||||
"""Postgres vector store."""
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
@@ -15,21 +15,26 @@ CFG = Config()
|
||||
class PGVectorConfig(VectorStoreConfig):
|
||||
"""PG vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
connection_string: str = Field(
|
||||
default=None,
|
||||
description="the connection string of vector store, if not set, will use the default connection string.",
|
||||
description="the connection string of vector store, if not set, will use the "
|
||||
"default connection string.",
|
||||
)
|
||||
|
||||
|
||||
class PGVectorStore(VectorStoreBase):
|
||||
"""`Postgres.PGVector` vector store.
|
||||
"""PG vector store.
|
||||
|
||||
To use this, you should have the ``pgvector`` python package installed.
|
||||
"""
|
||||
|
||||
def __init__(self, vector_store_config: PGVectorConfig) -> None:
|
||||
"""init pgvector storage"""
|
||||
|
||||
"""Create a PGVectorStore instance."""
|
||||
from langchain.vectorstores import PGVector
|
||||
|
||||
self.connection_string = vector_store_config.connection_string
|
||||
@@ -42,23 +47,43 @@ class PGVectorStore(VectorStoreBase):
|
||||
connection_string=self.connection_string,
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk, **kwargs: Any) -> None:
|
||||
def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]:
|
||||
"""Perform similar search in PGVector."""
|
||||
return self.vector_store_client.similarity_search(text, topk)
|
||||
|
||||
def vector_name_exists(self):
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""Check if vector name exists."""
|
||||
try:
|
||||
self.vector_store_client.create_collection()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("vector_name_exists error", e.message)
|
||||
logger.error(f"vector_name_exists error, {str(e)}")
|
||||
return False
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load document to PGVector.
|
||||
|
||||
Args:
|
||||
chunks(List[Chunk]): document chunks.
|
||||
|
||||
Return:
|
||||
List[str]: chunk ids.
|
||||
"""
|
||||
lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
|
||||
return self.vector_store_client.from_documents(lc_documents)
|
||||
|
||||
def delete_vector_name(self, vector_name):
|
||||
def delete_vector_name(self, vector_name: str):
|
||||
"""Delete vector by name.
|
||||
|
||||
Args:
|
||||
vector_name(str): vector name.
|
||||
"""
|
||||
return self.vector_store_client.delete_collection()
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
def delete_by_ids(self, ids: str):
|
||||
"""Delete vector by ids.
|
||||
|
||||
Args:
|
||||
ids(str): vector ids, separated by comma.
|
||||
"""
|
||||
return self.vector_store_client.delete(ids)
|
||||
|
@@ -1,14 +1,13 @@
|
||||
"""Weaviate vector store."""
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import Document
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
from .base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
@@ -17,6 +16,11 @@ CFG = Config()
|
||||
class WeaviateVectorConfig(VectorStoreConfig):
|
||||
"""Weaviate vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
weaviate_url: str = Field(
|
||||
default=os.getenv("WEAVIATE_URL", None),
|
||||
description="weaviate url address, if not set, will use the default url.",
|
||||
@@ -28,7 +32,7 @@ class WeaviateVectorConfig(VectorStoreConfig):
|
||||
|
||||
|
||||
class WeaviateStore(VectorStoreBase):
|
||||
"""Weaviate database"""
|
||||
"""Weaviate database."""
|
||||
|
||||
def __init__(self, vector_store_config: WeaviateVectorConfig) -> None:
|
||||
"""Initialize with Weaviate client."""
|
||||
@@ -49,8 +53,8 @@ class WeaviateStore(VectorStoreBase):
|
||||
|
||||
self.vector_store_client = weaviate.Client(self.weaviate_url)
|
||||
|
||||
def similar_search(self, text: str, topk: int) -> None:
|
||||
"""Perform similar search in Weaviate"""
|
||||
def similar_search(self, text: str, topk: int) -> List[Chunk]:
|
||||
"""Perform similar search in Weaviate."""
|
||||
logger.info("Weaviate similar search")
|
||||
# nearText = {
|
||||
# "concepts": [text],
|
||||
@@ -68,15 +72,16 @@ class WeaviateStore(VectorStoreBase):
|
||||
docs = []
|
||||
for r in res:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=r["page_content"],
|
||||
Chunk(
|
||||
content=r["page_content"],
|
||||
metadata={"metadata": r["metadata"]},
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""Check if a vector name exists for a given class in Weaviate.
|
||||
"""Whether the vector name exists in Weaviate.
|
||||
|
||||
Returns:
|
||||
bool: True if the vector name exists, False otherwise.
|
||||
"""
|
||||
@@ -85,14 +90,15 @@ class WeaviateStore(VectorStoreBase):
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("vector_name_exists error", e.message)
|
||||
logger.error(f"vector_name_exists error, {str(e)}")
|
||||
return False
|
||||
|
||||
def _default_schema(self) -> None:
|
||||
"""
|
||||
Create the schema for Weaviate with a Document class containing metadata and text properties.
|
||||
"""
|
||||
"""Create default schema in Weaviate.
|
||||
|
||||
Create the schema for Weaviate with a Document class containing metadata and
|
||||
text properties.
|
||||
"""
|
||||
schema = {
|
||||
"classes": [
|
||||
{
|
||||
@@ -137,7 +143,7 @@ class WeaviateStore(VectorStoreBase):
|
||||
self.vector_store_client.schema.create(schema)
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load documents into Weaviate"""
|
||||
"""Load document to Weaviate."""
|
||||
logger.info("Weaviate load document")
|
||||
texts = [doc.content for doc in chunks]
|
||||
metadatas = [doc.metadata for doc in chunks]
|
||||
@@ -157,3 +163,5 @@ class WeaviateStore(VectorStoreBase):
|
||||
data_object=properties, class_name=self.vector_name
|
||||
)
|
||||
self.vector_store_client.batch.flush()
|
||||
# TODO: return ids
|
||||
return []
|
||||
|
@@ -11,4 +11,5 @@ isort==5.10.1
|
||||
pyupgrade==3.1.0
|
||||
types-requests
|
||||
types-beautifulsoup4
|
||||
types-Markdown
|
||||
types-Markdown
|
||||
types-tqdm
|
Reference in New Issue
Block a user