chore: Add pylint for storage (#1298)

Fangyin Cheng
2024-03-15 15:42:46 +08:00
committed by GitHub
parent a207640ff2
commit 8897d6e8fd
50 changed files with 784 additions and 667 deletions

View File

@@ -8,8 +8,8 @@ follow_imports = skip
[mypy-dbgpt.datasource.*]
follow_imports = skip
[mypy-dbgpt.storage.*]
follow_imports = skip
# [mypy-dbgpt.storage.*]
# follow_imports = skip
[mypy-dbgpt.serve.*]
follow_imports = skip
@@ -57,4 +57,17 @@ ignore_missing_imports = True
[mypy-spacy.*]
ignore_missing_imports = True
follow_imports = skip
follow_imports = skip
# Storage
[mypy-msgpack.*]
ignore_missing_imports = True
[mypy-rocksdict.*]
ignore_missing_imports = True
[mypy-weaviate.*]
ignore_missing_imports = True
[mypy-pymilvus.*]
ignore_missing_imports = True

View File

@@ -50,6 +50,7 @@ fmt: setup ## Format Python code
# https://flake8.pycqa.org/en/latest/
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
$(VENV_BIN)/flake8 dbgpt/storage/
# TODO: More package checks with flake8.
.PHONY: fmt-check
@@ -60,8 +61,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
$(VENV_BIN)/blackdoc --check dbgpt examples
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
# $(VENV_BIN)/blackdoc --check dbgpt examples
# $(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/storage/
.PHONY: pre-commit
pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing
@@ -77,8 +77,10 @@ test-doc: $(VENV)/.testenv ## Run doctests
.PHONY: mypy
mypy: $(VENV)/.testenv ## Run mypy checks
# https://github.com/python/mypy
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/
# rag depends on core and storage, so we do not need to check it again.
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
# TODO: More package checks with mypy.
.PHONY: coverage

View File

@@ -83,7 +83,7 @@ CREATE TABLE IF NOT EXISTS `chat_history`
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
`chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode',
`summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
`summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
`message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',

View File

@@ -0,0 +1,4 @@
"""Old chat history module.
Just used by editor.
"""

View File

@@ -8,10 +8,10 @@ from dbgpt.core.interface.message import OnceConversation
class MemoryStoreType(Enum):
File = "file"
Memory = "memory"
# File = "file"
# Memory = "memory"
DB = "db"
DuckDb = "duckdb"
# DuckDb = "duckdb"
class BaseChatHistoryMemory(ABC):
@@ -24,18 +24,14 @@ class BaseChatHistoryMemory(ABC):
def messages(self) -> List[OnceConversation]: # type: ignore
"""Retrieve the messages from the local file"""
@abstractmethod
def create(self, user_name: str) -> None:
"""Append the message to the record in the local file"""
# @abstractmethod
# def create(self, user_name: str) -> None:
# """Append the message to the record in the local file"""
@abstractmethod
def append(self, message: OnceConversation) -> None:
"""Append the message to the record in the local file"""
# @abstractmethod
# def clear(self) -> None:
# """Clear session memory from the local file"""
@abstractmethod
def update(self, messages: List[OnceConversation]) -> None:
pass
@@ -45,14 +41,11 @@ class BaseChatHistoryMemory(ABC):
pass
@abstractmethod
def conv_info(self, conv_uid: Optional[str] = None) -> None:
pass
@abstractmethod
def get_messages(self) -> List[OnceConversation]:
def get_messages(self) -> List[Dict]:
pass
@staticmethod
@abstractmethod
def conv_list(
user_name: Optional[str] = None, sys_code: Optional[str] = None
) -> List[Dict]:

View File

@@ -1,13 +1,16 @@
"""Module for chat history factory.
It will be removed in the future; only the db store type is supported now.
"""
import logging
from typing import Type
from dbgpt._private.config import Config
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from .base import MemoryStoreType
from .base import BaseChatHistoryMemory, MemoryStoreType
# Import first for auto create table
from .store_type.meta_db_history import DbHistoryMemory
from .meta_db_history import DbHistoryMemory
# TODO remove global variable
CFG = Config()
@@ -20,13 +23,6 @@ class ChatHistory:
self.mem_store_class_map = {}
# Only the db store type is supported after v0.4.6
# from .store_type.duckdb_history import DuckdbHistoryMemory
# from .store_type.file_history import FileHistoryMemory
# from .store_type.mem_history import MemHistoryMemory
# self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
# self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
# self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
@@ -53,24 +49,26 @@ class ChatHistory:
Raises:
ValueError: Invalid store type
"""
from .store_type.duckdb_history import DuckdbHistoryMemory
from .store_type.file_history import FileHistoryMemory
from .store_type.mem_history import MemHistoryMemory
if store_type == MemHistoryMemory.store_type:
if store_type == "memory":
logger.error(
"Not support memory store type, just support db store type now"
)
raise ValueError(f"Invalid store type: {store_type}")
if store_type == FileHistoryMemory.store_type:
if store_type == "file":
logger.error("Not support file store type, just support db store type now")
raise ValueError(f"Invalid store type: {store_type}")
if store_type == DuckdbHistoryMemory.store_type:
link1 = "https://docs.dbgpt.site/docs/faq/install#q6-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-sqlitel"
link2 = "https://docs.dbgpt.site/docs/faq/install#q7-how-to-migrate-meta-table-chat_history-and-connect_config-from-duckdb-to-mysql"
if store_type == "duckdb":
link1 = (
"https://docs.dbgpt.site/docs/latest/faq/install#q6-how-to-migrate-meta"
"-table-chat_history-and-connect_config-from-duckdb-to-sqlite"
)
link2 = (
"https://docs.dbgpt.site/docs/latest/faq/install/#q7-how-to-migrate-"
"meta-table-chat_history-and-connect_config-from-duckdb-to-mysql"
)
logger.error(
"Not support duckdb store type after v0.4.6, just support db store type now, "
f"you can migrate your message according to {link1} or {link2}"
"Not support duckdb store type after v0.4.6, just support db store "
f"type now, you can migrate your message according to {link1} or {link2}"
)
raise ValueError(f"Invalid store type: {store_type}")
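
A minimal usage sketch of the factory after this change. The commented import mirrors the relative import added in the api_editor_v1 diff below; the conversation id is hypothetical.

# Imported relatively inside the editor-private package, as shown later in this commit:
# from ._chat_history.chat_hisotry_factory import ChatHistory

chat_history_fac = ChatHistory()
# Only the db store type is registered after v0.4.6; "memory", "file" and
# "duckdb" now raise ValueError (duckdb also logs the migration links above).
history_mem = chat_history_fac.get_store_instance("hypothetical-conv-uid")
history_messages = history_mem.get_messages()  # typed as List[Dict] after this change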

View File

@@ -4,64 +4,76 @@ from typing import Dict, List, Optional
from dbgpt._private.config import Config
from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
from dbgpt.storage.chat_history.chat_history_db import ChatHistoryDao, ChatHistoryEntity
from .base import BaseChatHistoryMemory, MemoryStoreType
CFG = Config()
logger = logging.getLogger(__name__)
class DbHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.DB.value
"""Db history memory storage.
It is deprecated.
"""
store_type: str = MemoryStoreType.DB.value # type: ignore
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.chat_history_dao = ChatHistoryDao()
def messages(self) -> List[OnceConversation]:
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
chat_history: Optional[ChatHistoryEntity] = self.chat_history_dao.get_by_uid(
self.chat_seesion_id
)
if chat_history:
context = chat_history.messages
if context:
conversations: List[OnceConversation] = json.loads(context)
conversations: List[OnceConversation] = json.loads(
context # type: ignore
)
return conversations
return []
def create(self, chat_mode, summary: str, user_name: str) -> None:
try:
chat_history: ChatHistoryEntity = ChatHistoryEntity()
chat_history.chat_mode = chat_mode
chat_history.summary = summary
chat_history.user_name = user_name
self.chat_history_dao.raw_update(chat_history)
except Exception as e:
logger.error("init create conversation log error" + str(e))
# def create(self, chat_mode, summary: str, user_name: str) -> None:
# try:
# chat_history: ChatHistoryEntity = ChatHistoryEntity()
# chat_history.chat_mode = chat_mode
# chat_history.summary = summary
# chat_history.user_name = user_name
#
# self.chat_history_dao.raw_update(chat_history)
# except Exception as e:
# logger.error("init create conversation log error" + str(e))
#
def append(self, once_message: OnceConversation) -> None:
logger.debug(f"db history append: {once_message}")
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
chat_history: Optional[ChatHistoryEntity] = self.chat_history_dao.get_by_uid(
self.chat_seesion_id
)
conversations: List[OnceConversation] = []
conversations: List[Dict] = []
latest_user_message = once_message.get_latest_user_message()
summary = latest_user_message.content if latest_user_message else ""
if chat_history:
context = chat_history.messages
if context:
conversations = json.loads(context)
conversations = json.loads(context) # type: ignore
else:
chat_history.summary = once_message.get_latest_user_message().content
chat_history.summary = summary # type: ignore
else:
chat_history: ChatHistoryEntity = ChatHistoryEntity()
chat_history.conv_uid = self.chat_seesion_id
chat_history.chat_mode = once_message.chat_mode
chat_history.user_name = once_message.user_name
chat_history.sys_code = once_message.sys_code
chat_history.summary = once_message.get_latest_user_message().content
chat_history = ChatHistoryEntity()
chat_history.conv_uid = self.chat_seesion_id # type: ignore
chat_history.chat_mode = once_message.chat_mode # type: ignore
chat_history.user_name = once_message.user_name # type: ignore
chat_history.sys_code = once_message.sys_code # type: ignore
chat_history.summary = summary # type: ignore
conversations.append(_conversation_to_dict(once_message))
chat_history.messages = json.dumps(conversations, ensure_ascii=False)
chat_history.messages = json.dumps( # type: ignore
conversations, ensure_ascii=False
)
self.chat_history_dao.raw_update(chat_history)
@@ -72,18 +84,13 @@ class DbHistoryMemory(BaseChatHistoryMemory):
def delete(self) -> bool:
self.chat_history_dao.raw_delete(self.chat_seesion_id)
return True
def conv_info(self, conv_uid: str = None) -> None:
logger.info("conv_info:{}", conv_uid)
chat_history = self.chat_history_dao.get_by_uid(conv_uid)
return chat_history.__dict__
def get_messages(self) -> List[OnceConversation]:
# logger.info("get_messages:{}", self.chat_seesion_id)
def get_messages(self) -> List[Dict]:
chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
if chat_history:
context = chat_history.messages
return json.loads(context)
return json.loads(context) # type: ignore
return []
@staticmethod

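A rough sketch of the surviving surface of the deprecated DbHistoryMemory, using only the methods visible in this diff; the conversation id is hypothetical, and after this commit the class lives in the editor-private _chat_history package.

memory = DbHistoryMemory("hypothetical-conv-uid")
conversations = memory.messages()  # List[OnceConversation] decoded from the JSON column
raw_dicts = memory.get_messages()  # List[Dict] after this change
deleted = memory.delete()          # raw_delete by conv_uid, returns True
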
View File

@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import List
from typing import Dict, List
from fastapi import APIRouter, Body, Depends
@@ -23,9 +23,9 @@ from dbgpt.app.openapi.editor_view_model import (
)
from dbgpt.app.scene import ChatFactory
from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
from dbgpt.core.interface.message import OnceConversation
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
from ._chat_history.chat_hisotry_factory import ChatHistory
router = APIRouter()
CFG = Config()
@@ -201,7 +201,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
chat_history_fac = ChatHistory()
history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
history_messages: List[Dict] = history_mem.get_messages()
if history_messages:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)

View File

@@ -21,10 +21,9 @@ from dbgpt.core.awel import (
from dbgpt.core.operators import (
BufferedConversationMapperOperator,
HistoryPromptBuilderOperator,
LLMBranchOperator,
)
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.storage.cache.operator import (
from dbgpt.storage.cache.operators import (
CachedModelOperator,
CachedModelStreamOperator,
CacheManager,

View File

@@ -33,7 +33,7 @@ async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIter
yield iter_data
class IteratorTrigger(Trigger):
class IteratorTrigger(Trigger[List[Tuple[Any, Any]]]):
"""Trigger for iterator data.
Trigger the dag with iterator data.
@@ -46,6 +46,7 @@ class IteratorTrigger(Trigger):
data: IterDataType,
parallel_num: int = 1,
streaming_call: bool = False,
show_progress: bool = True,
**kwargs
):
"""Create an IteratorTrigger.
@@ -60,6 +61,7 @@ class IteratorTrigger(Trigger):
self._iter_data = data
self._parallel_num = parallel_num
self._streaming_call = streaming_call
self._show_progress = show_progress
super().__init__(**kwargs)
async def trigger(
@@ -132,17 +134,27 @@ class IteratorTrigger(Trigger):
async def call_stream(call_data: Any):
async for out in await end_node.call_stream(call_data):
yield out
await dag._after_dag_end()
async def run_node(call_data: Any):
async def run_node(call_data: Any) -> Tuple[Any, Any]:
async with semaphore:
if streaming_call:
task_output = call_stream(call_data)
else:
task_output = await end_node.call(call_data)
await dag._after_dag_end()
return call_data, task_output
tasks = []
if self._show_progress:
from tqdm.asyncio import tqdm_asyncio
async_module = tqdm_asyncio
else:
async_module = asyncio # type: ignore
async for data in _to_async_iterator(self._iter_data, task_id):
tasks.append(run_node(data))
results = await asyncio.gather(*tasks)
results: List[Tuple[Any, Any]] = await async_module.gather(*tasks)
return results
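
A hedged sketch of driving a DAG with the updated IteratorTrigger; the import path of IteratorTrigger and the MapOperator wiring are assumptions based on the AWEL API, and show_progress=False avoids the tqdm dependency (True, the default, gathers through tqdm_asyncio).

import asyncio

from dbgpt.core.awel import DAG, MapOperator
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger  # assumed path


async def main():
    with DAG("hypothetical_dag"):
        trigger = IteratorTrigger(data=[1, 2, 3], show_progress=False)
        task = MapOperator(lambda x: x * 2)
        trigger >> task
    # Each element is an (input_data, task_output) tuple, per the new return type
    results = await trigger.trigger()
    print(results)


asyncio.run(main())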

View File

@@ -10,7 +10,7 @@ from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
from dbgpt.rag.embedding import Embeddings
from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings
class EmbeddingLoader:
@@ -47,7 +47,7 @@ class EmbeddingLoader:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIEmbeddings(**openapi_param)
else:
from langchain.embeddings import HuggingFaceEmbeddings
from dbgpt.rag.embedding import HuggingFaceEmbeddings
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)

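In practice this removes the langchain dependency for local embeddings; a hedged sketch of the replacement import (the model name is only an example, and the constructor and method names are assumed to mirror the langchain class they replace).

from dbgpt.rag.embedding import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en")  # example model
vectors = embeddings.embed_documents(["hello world"])
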
View File

@@ -1,7 +1,6 @@
from typing import List
from langchain.embeddings.base import Embeddings
from dbgpt.core import Embeddings
from dbgpt.model.cluster.manager_base import WorkerManager

View File

@@ -20,14 +20,8 @@ logger = logging.getLogger(__name__)
class EmbeddingsModelWorker(ModelWorker):
def __init__(self) -> None:
try:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.base import Embeddings
except ImportError as exc:
raise ImportError(
"Could not import langchain.embeddings.HuggingFaceEmbeddings python package. "
"Please install it with `pip install langchain`."
) from exc
from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings
self._embeddings_impl: Embeddings = None
self._model_params = None
self.model_name = None

View File

@@ -3,8 +3,7 @@
from enum import Enum
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.rag.chunk import Chunk, Document
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge

View File

@@ -411,6 +411,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
end_node = cast(BaseOperator, leaf_nodes[0])
async for output in _chat_with_dag_task(end_node, request, incremental):
yield output
await dag._after_dag_end()
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
"""Parse the flow category

View File

@@ -0,0 +1,3 @@
"""Module of storage."""
from .schema import DBType # noqa: F401

View File

@@ -1,6 +1,7 @@
from dbgpt.storage.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from dbgpt.storage.cache.manager import CacheManager, initialize_cache
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
"""Module for cache storage."""
from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue # noqa: F401
from .manager import CacheManager, initialize_cache # noqa: F401
from .storage.base import MemoryCacheStorage # noqa: F401
__all__ = [
"LLMCacheKey",

View File

@@ -0,0 +1 @@
"""Embeddings cache."""

View File

@@ -1,21 +1,26 @@
"""Cache client for LLM."""
import hashlib
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast
from dbgpt.core import ModelOutput, Serializer
from dbgpt.core import ModelOutput
from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
from dbgpt.model.base import ModelType
from dbgpt.storage.cache.manager import CacheManager
from .manager import CacheManager
@dataclass
class LLMCacheKeyData:
"""Cache key data for LLM."""
prompt: str
model_name: str
temperature: Optional[float] = 0.7
max_new_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
model_type: Optional[str] = ModelType.HF
# See dbgpt.model.base.ModelType
model_type: Optional[str] = "huggingface"
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
@@ -23,12 +28,15 @@ CacheOutputType = Union[ModelOutput, List[ModelOutput]]
@dataclass
class LLMCacheValueData:
"""Cache value data for LLM."""
output: CacheOutputType
user: Optional[str] = None
_is_list: Optional[bool] = False
_is_list: bool = False
@staticmethod
def from_dict(**kwargs) -> "LLMCacheValueData":
"""Create LLMCacheValueData object from dict."""
output = kwargs.get("output")
if not output:
raise ValueError("Can't create LLMCacheValueData object, output is None")
@@ -46,6 +54,7 @@ class LLMCacheValueData:
return LLMCacheValueData(**kwargs)
def to_dict(self) -> Dict:
"""Convert to dict."""
output = self.output
is_list = False
if isinstance(output, list):
@@ -53,16 +62,18 @@ class LLMCacheValueData:
is_list = True
for out in output:
output_list.append(out.to_dict())
output = output_list
output = output_list # type: ignore
else:
output = output.to_dict()
output = output.to_dict() # type: ignore
return {"output": output, "_is_list": is_list, "user": self.user}
@property
def is_list(self) -> bool:
"""Return whether the output is a list."""
return self._is_list
def __str__(self) -> str:
"""Return string representation."""
if not isinstance(self.output, list):
return f"user: {self.user}, output: {self.output}"
else:
@@ -70,74 +81,116 @@ class LLMCacheValueData:
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
"""Cache key for LLM."""
def __init__(self, **kwargs) -> None:
"""Create a new instance of LLMCacheKey."""
super().__init__()
self.config = LLMCacheKeyData(**kwargs)
def __hash__(self) -> int:
"""Return the hash value of the object."""
serialize_bytes = self.serialize()
return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
def __eq__(self, other: Any) -> bool:
"""Check equality with another key."""
if not isinstance(other, LLMCacheKey):
return False
return self.config == other.config
def get_hash_bytes(self) -> bytes:
"""Return the byte array of hash value.
Returns:
bytes: The byte array of hash value.
"""
serialize_bytes = self.serialize()
return hashlib.sha256(serialize_bytes).digest()
def to_dict(self) -> Dict:
"""Convert to dict."""
return asdict(self.config)
def get_value(self) -> LLMCacheKeyData:
"""Return the real object of current cache key."""
return self.config
class LLMCacheValue(CacheValue[LLMCacheValueData]):
"""Cache value for LLM."""
def __init__(self, **kwargs) -> None:
"""Create a new instance of LLMCacheValue."""
super().__init__()
self.value = LLMCacheValueData.from_dict(**kwargs)
def to_dict(self) -> Dict:
"""Convert to dict."""
return self.value.to_dict()
def get_value(self) -> LLMCacheValueData:
"""Return the underlying real value."""
return self.value
def __str__(self) -> str:
"""Return string representation."""
return f"value: {str(self.value)}"
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
"""Cache client for LLM."""
def __init__(self, cache_manager: CacheManager) -> None:
"""Create a new instance of LLMCacheClient."""
super().__init__()
self._cache_manager: CacheManager = cache_manager
async def get(
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
self,
key: LLMCacheKey, # type: ignore
cache_config: Optional[CacheConfig] = None,
) -> Optional[LLMCacheValue]:
return await self._cache_manager.get(key, LLMCacheValue, cache_config)
"""Retrieve a value from the cache using the provided key.
Args:
key (LLMCacheKey): The key to get cache
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[LLMCacheValue]: The value retrieved according to key. If cache key
not exist, return None.
"""
return cast(
LLMCacheValue,
await self._cache_manager.get(key, LLMCacheValue, cache_config),
)
async def set(
self,
key: LLMCacheKey,
value: LLMCacheValue,
key: LLMCacheKey, # type: ignore
value: LLMCacheValue, # type: ignore
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key."""
return await self._cache_manager.set(key, value, cache_config)
async def exists(
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
self,
key: LLMCacheKey, # type: ignore
cache_config: Optional[CacheConfig] = None,
) -> bool:
"""Check if a key exists in the cache."""
return await self.get(key, cache_config) is not None
def new_key(self, **kwargs) -> LLMCacheKey:
def new_key(self, **kwargs) -> LLMCacheKey: # type: ignore
"""Create a cache key with params."""
key = LLMCacheKey(**kwargs)
key.set_serializer(self._cache_manager.serializer)
return key
def new_value(self, **kwargs) -> LLMCacheValue:
def new_value(self, **kwargs) -> LLMCacheValue: # type: ignore
"""Create a cache value with params."""
value = LLMCacheValue(**kwargs)
value.set_serializer(self._cache_manager.serializer)
return value
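
A hedged round-trip sketch of the refactored client API; the keyword arguments come from LLMCacheKeyData/LLMCacheValueData above, while the cache_manager and the field values are placeholders.

from dbgpt.core import ModelOutput
from dbgpt.storage.cache import LLMCacheClient


async def cache_roundtrip(cache_manager):
    # cache_manager: a CacheManager, e.g. obtained after initialize_cache()
    client = LLMCacheClient(cache_manager)
    key = client.new_key(prompt="hello", model_name="example-model", temperature=0.7)
    value = client.new_value(output=ModelOutput(text="hi there", error_code=0))
    await client.set(key, value)
    cached = await client.get(key)  # Optional[LLMCacheValue]
    return cached.get_value().output if cached else None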

View File

@@ -1,24 +1,31 @@
"""Cache manager."""
import logging
from abc import ABC, abstractmethod
from concurrent.futures import Executor
from typing import Optional, Type
from typing import Optional, Type, cast
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import CacheConfig, CacheKey, CacheValue, Serializable, Serializer
from dbgpt.core.interface.cache import K, V
from dbgpt.storage.cache.storage.base import CacheStorage
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from .storage.base import CacheStorage
logger = logging.getLogger(__name__)
class CacheManager(BaseComponent, ABC):
"""The cache manager interface."""
name = ComponentType.MODEL_CACHE_MANAGER
def __init__(self, system_app: SystemApp | None = None):
"""Create cache manager."""
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
"""Initialize cache manager."""
self.system_app = system_app
@abstractmethod
@@ -28,7 +35,7 @@ class CacheManager(BaseComponent, ABC):
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
):
"""Set cache"""
"""Set cache with key."""
@abstractmethod
async def get(
@@ -36,27 +43,30 @@ class CacheManager(BaseComponent, ABC):
key: CacheKey[K],
cls: Type[Serializable],
cache_config: Optional[CacheConfig] = None,
) -> CacheValue[V]:
"""Get cache with key"""
) -> Optional[CacheValue[V]]:
"""Retrieve cache with key."""
@property
@abstractmethod
def serializer(self) -> Serializer:
"""Get cache serializer"""
"""Return serializer to serialize/deserialize cache value."""
class LocalCacheManager(CacheManager):
"""Local cache manager."""
def __init__(
self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
) -> None:
"""Create local cache manager."""
super().__init__(system_app)
self._serializer = serializer
self._storage = storage
@property
def executor(self) -> Executor:
"""Return executor to submit task"""
self._executor = self.system_app.get_component(
"""Return executor."""
return self.system_app.get_component( # type: ignore
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
@@ -66,6 +76,7 @@ class LocalCacheManager(CacheManager):
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
):
"""Set cache with key."""
if self._storage.support_async():
await self._storage.aset(key, value, cache_config)
else:
@@ -78,7 +89,8 @@ class LocalCacheManager(CacheManager):
key: CacheKey[K],
cls: Type[Serializable],
cache_config: Optional[CacheConfig] = None,
) -> CacheValue[V]:
) -> Optional[CacheValue[V]]:
"""Retrieve cache with key."""
if self._storage.support_async():
item_bytes = await self._storage.aget(key, cache_config)
else:
@@ -87,30 +99,42 @@ class LocalCacheManager(CacheManager):
)
if not item_bytes:
return None
return self._serializer.deserialize(item_bytes.value_data, cls)
return cast(
CacheValue[V], self._serializer.deserialize(item_bytes.value_data, cls)
)
@property
def serializer(self) -> Serializer:
"""Return serializer to serialize/deserialize cache value."""
return self._serializer
def initialize_cache(
system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
):
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
"""Initialize cache manager.
Args:
system_app (SystemApp): The system app.
storage_type (str): The storage type.
max_memory_mb (int): The max memory in MB.
persist_dir (str): The persist directory.
"""
from dbgpt.util.serialization.json_serialization import JsonSerializer
cache_storage = None
from .storage.base import MemoryCacheStorage
if storage_type == "disk":
try:
from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage
from .storage.disk.disk_storage import DiskCacheStorage
cache_storage = DiskCacheStorage(
cache_storage: CacheStorage = DiskCacheStorage(
persist_dir, mem_table_buffer_mb=max_memory_mb
)
except ImportError as e:
logger.warn(
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error "
f"message: {str(e)}"
)
cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
else:

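A small sketch of initializing the cache manager with the documented arguments; the SystemApp construction and the persist directory are placeholders.

from dbgpt.component import SystemApp
from dbgpt.storage.cache import initialize_cache

system_app = SystemApp()
# Requests disk storage; falls back to MemoryCacheStorage if rocksdict is missing
initialize_cache(
    system_app, storage_type="disk", max_memory_mb=256, persist_dir="/tmp/dbgpt_cache"
)
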
View File

@@ -1,5 +1,6 @@
"""Operators for processing model outputs with caching support."""
import logging
from typing import AsyncIterator, Dict, List, Union
from typing import AsyncIterator, Dict, List, Optional, Union, cast
from dbgpt.core import ModelOutput, ModelRequest
from dbgpt.core.awel import (
@@ -10,7 +11,9 @@ from dbgpt.core.awel import (
StreamifyAbsOperator,
TransformStreamAbsOperator,
)
from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue
from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from .manager import CacheManager
logger = logging.getLogger(__name__)
@@ -26,15 +29,17 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
**kwargs: Additional keyword arguments.
Methods:
streamify: Processes a stream of inputs with cache support, yielding model outputs.
streamify: Processes a stream of inputs with cache support, yielding model
outputs.
"""
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
"""Create a new instance of CachedModelStreamOperator."""
super().__init__(**kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
async def streamify(self, input_value: ModelRequest) -> AsyncIterator[ModelOutput]:
async def streamify(self, input_value: ModelRequest):
"""Process inputs as a stream with cache support and yield model outputs.
Args:
@@ -45,10 +50,13 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
llm_cache_value = await self._client.get(llm_cache_key)
logger.info(f"llm_cache_value: {llm_cache_value}")
for out in llm_cache_value.get_value().output:
yield out
if not llm_cache_value:
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
outputs = cast(List[ModelOutput], llm_cache_value.get_value().output)
for out in outputs:
yield cast(ModelOutput, out)
class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
@@ -63,6 +71,7 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
"""
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
"""Create a new instance of CachedModelOperator."""
super().__init__(**kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
@@ -78,14 +87,18 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
llm_cache_value = await self._client.get(llm_cache_key)
if not llm_cache_value:
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
logger.info(f"llm_cache_value: {llm_cache_value}")
return llm_cache_value.get_value().output
return cast(ModelOutput, llm_cache_value.get_value().output)
class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
"""
A branch operator that decides whether to use cached data or to process data using the model.
"""Branch operator for model processing with cache support.
A branch operator that decides whether to use cached data or to process data using
the model.
Args:
cache_manager (CacheManager): The cache manager for managing cache operations.
@@ -101,6 +114,7 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
cache_task_name: str,
**kwargs,
):
"""Create a new instance of ModelCacheBranchOperator."""
super().__init__(branches=None, **kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
@@ -110,10 +124,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
async def branches(
self,
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
"""Defines branch logic based on cache availability.
"""Branch logic based on cache availability.
Defines branch logic based on cache availability.
Returns:
Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping
branch functions to task names.
"""
async def check_cache_true(input_value: ModelRequest) -> bool:
@@ -124,12 +141,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
cache_value = await self._client.get(cache_key)
logger.debug(
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: "
f"{cache_value}"
)
await self.current_dag_context.save_to_share_data(
_LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True
)
return True if cache_value else False
return bool(cache_value)
async def check_cache_false(input_value: ModelRequest):
# Inverse of check_cache_true
@@ -152,22 +170,25 @@ class ModelStreamSaveCacheOperator(
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
"""Create a new instance of ModelStreamSaveCacheOperator."""
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[ModelOutput]:
"""Transforms the input stream by saving the outputs to cache.
async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
"""Save the stream of model outputs to cache.
Transforms the input stream by saving the outputs to cache.
Args:
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model
outputs.
Returns:
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are
saved to cache.
"""
llm_cache_key: LLMCacheKey = None
llm_cache_key: Optional[LLMCacheKey] = None
outputs = []
async for out in input_value:
if not llm_cache_key:
@@ -190,12 +211,13 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
"""Create a new instance of ModelSaveCacheOperator."""
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def map(self, input_value: ModelOutput) -> ModelOutput:
"""Saves a single model output to cache and returns it.
"""Save model output to cache.
Args:
input_value (ModelOutput): The output from the model to be cached.
@@ -213,7 +235,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
def _parse_cache_key_dict(input_value: ModelRequest) -> Dict:
"""Parses and extracts relevant fields from input to form a cache key dictionary.
"""Parse and extract relevant fields from input to form a cache key dictionary.
Args:
input_value (Dict): The input dictionary containing model and prompt parameters.

View File

@@ -0,0 +1 @@
"""Module for protocol."""

View File

@@ -0,0 +1 @@
"""Module for cache storage implementation."""

View File

@@ -1,3 +1,4 @@
"""Base cache storage class."""
import logging
from abc import ABC, abstractmethod
from collections import OrderedDict
@@ -22,8 +23,7 @@ logger = logging.getLogger(__name__)
@dataclass
class StorageItem:
"""
A class representing a storage item.
"""A class representing a storage item.
This class encapsulates data related to a storage item, such as its length,
the hash of the key, and the data for both the key and value.
@@ -44,6 +44,7 @@ class StorageItem:
def build_from(
key_hash: bytes, key_data: bytes, value_data: bytes
) -> "StorageItem":
"""Build a StorageItem from the provided key and value data."""
length = (
32
+ _get_object_bytes(key_hash)
@@ -56,6 +57,7 @@ class StorageItem:
@staticmethod
def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
"""Build a StorageItem from the provided key and value."""
key_hash = key.get_hash_bytes()
key_data = key.serialize()
value_data = value.serialize()
@@ -105,6 +107,8 @@ class StorageItem:
class CacheStorage(ABC):
"""Base class for cache storage."""
@abstractmethod
def check_config(
self,
@@ -122,6 +126,7 @@ class CacheStorage(ABC):
"""
def support_async(self) -> bool:
"""Check whether the storage support async operation."""
return False
@abstractmethod
@@ -135,7 +140,8 @@ class CacheStorage(ABC):
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None.
Optional[StorageItem]: The storage item retrieved according to key. If
cache key not exist, return None.
"""
async def aget(
@@ -148,7 +154,8 @@ class CacheStorage(ABC):
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None.
Optional[StorageItem]: The storage item of bytes retrieved according to
key. If cache key not exist, return None.
"""
raise NotImplementedError
@@ -184,8 +191,11 @@ class CacheStorage(ABC):
class MemoryCacheStorage(CacheStorage):
"""A simple in-memory cache storage implementation."""
def __init__(self, max_memory_mb: int = 256):
self.cache = OrderedDict()
"""Create a new instance of MemoryCacheStorage."""
self.cache: OrderedDict = OrderedDict()
self.max_memory = max_memory_mb * 1024 * 1024
self.current_memory_usage = 0
@@ -194,6 +204,7 @@ class MemoryCacheStorage(CacheStorage):
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
"""Check whether the CacheConfig is legal."""
if (
cache_config
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
@@ -208,10 +219,11 @@ class MemoryCacheStorage(CacheStorage):
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
"""Retrieve a storage item from the cache using the provided key."""
self.check_config(cache_config, raise_error=True)
# Exact match retrieval
key_hash = hash(key)
item: StorageItem = self.cache.get(key_hash)
item: Optional[StorageItem] = self.cache.get(key_hash)
logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
if not item:
@@ -226,6 +238,7 @@ class MemoryCacheStorage(CacheStorage):
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key."""
key_hash = hash(key)
item = StorageItem.build_from_kv(key, value)
# Calculate memory size of the new entry
@@ -242,6 +255,7 @@ class MemoryCacheStorage(CacheStorage):
def exists(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> bool:
"""Check if the key exists in the cache."""
return self.get(key, cache_config) is not None
def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):

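A tiny synchronous sketch of MemoryCacheStorage with the LLM key/value types from this commit; the JsonSerializer usage and the field values are illustrative assumptions.

from dbgpt.core import ModelOutput
from dbgpt.storage.cache import LLMCacheKey, LLMCacheValue, MemoryCacheStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer

serializer = JsonSerializer()
storage = MemoryCacheStorage(max_memory_mb=64)

key = LLMCacheKey(prompt="hello", model_name="example-model")
value = LLMCacheValue(output=ModelOutput(text="hi", error_code=0))
# Keys and values must carry a serializer before they can be stored
key.set_serializer(serializer)
value.set_serializer(serializer)

storage.set(key, value)     # builds a StorageItem and enforces the memory budget
print(storage.exists(key))  # True, via an exact-match lookup on hash(key)
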
View File

@@ -0,0 +1 @@
"""Disk cache storage implementation."""

View File

@@ -1,3 +1,7 @@
"""Disk storage for cache.
Implement the cache storage using rocksdb.
"""
import logging
from typing import Optional
@@ -11,14 +15,14 @@ from dbgpt.core.interface.cache import (
RetrievalPolicy,
V,
)
from dbgpt.storage.cache.storage.base import CacheStorage, StorageItem
from ..base import CacheStorage, StorageItem
logger = logging.getLogger(__name__)
def db_options(
mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
):
def db_options(mem_table_buffer_mb: int = 256, background_threads: int = 2):
"""Create rocksdb options."""
opt = Options()
# create table
opt.create_if_missing(True)
@@ -42,9 +46,10 @@ def db_options(
class DiskCacheStorage(CacheStorage):
def __init__(
self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
) -> None:
"""Disk cache storage using rocksdb."""
def __init__(self, persist_dir: str, mem_table_buffer_mb: int = 256) -> None:
"""Create a new instance of DiskCacheStorage."""
super().__init__()
self.db: Rdict = Rdict(
persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
@@ -55,6 +60,7 @@ class DiskCacheStorage(CacheStorage):
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
"""Check whether the CacheConfig is legal."""
if (
cache_config
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
@@ -69,6 +75,7 @@ class DiskCacheStorage(CacheStorage):
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
"""Retrieve a storage item from the cache using the provided key."""
self.check_config(cache_config, raise_error=True)
# Exact match retrieval
@@ -86,6 +93,7 @@ class DiskCacheStorage(CacheStorage):
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key."""
item = StorageItem.build_from_kv(key, value)
key_hash = item.key_hash
self.db[key_hash] = item.serialize()

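DiskCacheStorage keeps the same synchronous CacheStorage contract but persists items through rocksdict; a short sketch (the directory is a placeholder and rocksdict must be installed).

from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage

# Same set/get/exists API as the memory sketch above, persisted via rocksdb
storage = DiskCacheStorage("/tmp/dbgpt_cache", mem_table_buffer_mb=64)
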
View File

@@ -1,5 +1,3 @@
import pytest
from dbgpt.util.memory_utils import _get_object_bytes
from ..base import StorageItem

View File

@@ -1 +1,19 @@
"""Module of chat history."""
from .chat_history_db import ( # noqa: F401
ChatHistoryDao,
ChatHistoryEntity,
ChatHistoryMessageEntity,
)
from .storage_adapter import ( # noqa: F401
DBMessageStorageItemAdapter,
DBStorageConversationItemAdapter,
)
__all__ = [
"ChatHistoryEntity",
"ChatHistoryMessageEntity",
"ChatHistoryDao",
"DBStorageConversationItemAdapter",
"DBMessageStorageItemAdapter",
]

View File

@@ -1,12 +1,15 @@
"""Chat history database model."""
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
from ..metadata import BaseDao, Model
class ChatHistoryEntity(Model):
"""Chat history entity."""
__tablename__ = "chat_history"
__table_args__ = (UniqueConstraint("conv_uid", name="uk_conv_uid"),)
id = Column(
@@ -14,13 +17,16 @@ class ChatHistoryEntity(Model):
)
conv_uid = Column(
String(255),
# Change from False to True, the alembic migration will fail, so we use UniqueConstraint to replace it
# Changing this from False to True would break the alembic migration, so we
# use a UniqueConstraint instead
unique=False,
nullable=False,
comment="Conversation record unique id",
)
chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
summary = Column(String(255), nullable=False, comment="Conversation record summary")
summary = Column(
Text(length=2**31 - 1), nullable=False, comment="Conversation record summary"
)
user_name = Column(String(255), nullable=True, comment="interlocutor")
messages = Column(
Text(length=2**31 - 1), nullable=True, comment="Conversation details"
@@ -38,6 +44,8 @@ class ChatHistoryEntity(Model):
class ChatHistoryMessageEntity(Model):
"""Chat history message entity."""
__tablename__ = "chat_history_message"
__table_args__ = (
UniqueConstraint("conv_uid", "index", name="uk_conversation_message"),
@@ -61,9 +69,12 @@ class ChatHistoryMessageEntity(Model):
class ChatHistoryDao(BaseDao):
"""Chat history dao."""
def list_last_20(
self, user_name: Optional[str] = None, sys_code: Optional[str] = None
):
"""Retrieve the last 20 chat history records."""
session = self.get_raw_session()
chat_history = session.query(ChatHistoryEntity)
if user_name:
@@ -78,6 +89,7 @@ class ChatHistoryDao(BaseDao):
return result
def raw_update(self, entity: ChatHistoryEntity):
"""Update the chat history record."""
session = self.get_raw_session()
try:
updated = session.merge(entity)
@@ -87,6 +99,7 @@ class ChatHistoryDao(BaseDao):
session.close()
def update_message_by_uid(self, message: str, conv_uid: str):
"""Update the messages of a chat history record by conv_uid."""
session = self.get_raw_session()
try:
chat_history = session.query(ChatHistoryEntity)
@@ -97,7 +110,8 @@ class ChatHistoryDao(BaseDao):
finally:
session.close()
def raw_delete(self, conv_uid: int):
def raw_delete(self, conv_uid: str):
"""Delete the chat history record."""
if conv_uid is None:
raise Exception("conv_uid is None")
with self.session() as session:
@@ -106,5 +120,6 @@ class ChatHistoryDao(BaseDao):
chat_history.delete()
def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]:
"""Retrieve the chat history record by conv_uid."""
with self.session(commit=False) as session:
return session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid).first()
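
A short sketch of the DAO surface touched here; it assumes the metadata database has already been initialized, and the conv_uid value is hypothetical.

from dbgpt.storage.chat_history import ChatHistoryDao

dao = ChatHistoryDao()
recent = dao.list_last_20(user_name="alice")      # last 20 records, optionally filtered
record = dao.get_by_uid("hypothetical-conv-uid")  # Optional[ChatHistoryEntity]
if record:
    dao.raw_delete("hypothetical-conv-uid")       # conv_uid is a str after this change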

View File

@@ -1,5 +1,7 @@
"""Adapter for chat history storage."""
import json
from typing import Dict, List, Type
from typing import Dict, List, Optional, Type
from sqlalchemy.orm import Session
@@ -20,16 +22,23 @@ from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity
class DBStorageConversationItemAdapter(
StorageItemAdapter[StorageConversation, ChatHistoryEntity]
):
"""Adapter for chat history storage."""
def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity:
"""Convert to storage format."""
message_ids = ",".join(item.message_ids)
messages = None
if not item.save_message_independent and item.messages:
message_dict_list = [_conversation_to_dict(item)]
messages = json.dumps(message_dict_list, ensure_ascii=False)
summary = item.summary
latest_user_message = item.get_latest_user_message()
if not summary and latest_user_message is not None:
summary = latest_user_message.content
return ChatHistoryEntity(
conv_uid=item.conv_uid,
chat_mode=item.chat_mode,
summary=item.summary or item.get_latest_user_message().content,
summary=summary,
user_name=item.user_name,
# We not save messages to chat_history table in new design
messages=messages,
@@ -38,25 +47,25 @@ class DBStorageConversationItemAdapter(
)
def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation:
"""Convert from storage format."""
message_ids = model.message_ids.split(",") if model.message_ids else []
old_conversations: List[Dict] = (
json.loads(model.messages) if model.messages else []
json.loads(model.messages) if model.messages else [] # type: ignore
)
old_messages = []
save_message_independent = True
if old_conversations:
# Load old messages from old conversations, in old design, we save messages to chat_history table
# Load old messages from old conversations; in the old design, messages were
# saved to the chat_history table
save_message_independent = False
old_messages: List[BaseMessage] = _parse_old_conversations(
old_conversations
)
old_messages = _parse_old_conversations(old_conversations)
return StorageConversation(
conv_uid=model.conv_uid,
chat_mode=model.chat_mode,
summary=model.summary,
user_name=model.user_name,
conv_uid=model.conv_uid, # type: ignore
chat_mode=model.chat_mode, # type: ignore
summary=model.summary, # type: ignore
user_name=model.user_name, # type: ignore
message_ids=message_ids,
sys_code=model.sys_code,
sys_code=model.sys_code, # type: ignore
save_message_independent=save_message_independent,
messages=old_messages,
)
@@ -64,10 +73,11 @@ class DBStorageConversationItemAdapter(
def get_query_for_identifier(
self,
storage_format: Type[ChatHistoryEntity],
resource_id: ConversationIdentifier,
resource_id: ConversationIdentifier, # type: ignore
**kwargs,
):
session: Session = kwargs.get("session")
"""Get query for identifier."""
session: Optional[Session] = kwargs.get("session")
if session is None:
raise Exception("session is None")
return session.query(ChatHistoryEntity).filter(
@@ -78,7 +88,10 @@ class DBStorageConversationItemAdapter(
class DBMessageStorageItemAdapter(
StorageItemAdapter[MessageStorageItem, ChatHistoryMessageEntity]
):
"""Adapter for chat history message storage."""
def to_storage_format(self, item: MessageStorageItem) -> ChatHistoryMessageEntity:
"""Convert to storage format."""
round_index = item.message_detail.get("round_index", 0)
message_detail = json.dumps(item.message_detail, ensure_ascii=False)
return ChatHistoryMessageEntity(
@@ -91,22 +104,26 @@ class DBMessageStorageItemAdapter(
def from_storage_format(
self, model: ChatHistoryMessageEntity
) -> MessageStorageItem:
"""Convert from storage format."""
message_detail = (
json.loads(model.message_detail) if model.message_detail else {}
json.loads(model.message_detail) # type: ignore
if model.message_detail
else {}
)
return MessageStorageItem(
conv_uid=model.conv_uid,
index=model.index,
conv_uid=model.conv_uid, # type: ignore
index=model.index, # type: ignore
message_detail=message_detail,
)
def get_query_for_identifier(
self,
storage_format: Type[ChatHistoryMessageEntity],
resource_id: MessageIdentifier,
resource_id: MessageIdentifier, # type: ignore
**kwargs,
):
session: Session = kwargs.get("session")
"""Get query for identifier."""
session: Optional[Session] = kwargs.get("session")
if session is None:
raise Exception("session is None")
return session.query(ChatHistoryMessageEntity).filter(

View File

@@ -1,182 +0,0 @@
import json
import os
from typing import Dict, List, Optional
import duckdb
from dbgpt._private.config import Config
from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.core.interface.message import OnceConversation, _conversation_to_dict
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from ..base import MemoryStoreType
default_db_path = os.path.join(PILOT_PATH, "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = "chat_history"
CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.DuckDb.value
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True)
self.connect = duckdb.connect(duckdb_path)
self.__init_chat_history_tables()
def __init_chat_history_tables(self):
# Check whether the table exists
result = self.connect.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
).fetchall()
if not result:
# Create a new table if it does not exist
self.connect.execute(
"CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)"
)
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
content = cursor.fetchone()
if content:
return content[0]
else:
return None
def messages(self) -> List[OnceConversation]:
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
if context:
conversations: List[OnceConversation] = json.loads(context)
return conversations
return []
def create(self, chat_mode, summary: str, user_name: str) -> None:
try:
cursor = self.connect.cursor()
cursor.execute(
"INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
[self.chat_seesion_id, chat_mode, summary, user_name, "", ""],
)
cursor.commit()
self.connect.commit()
except Exception as e:
print("init create conversation log error" + str(e))
def append(self, once_message: OnceConversation) -> None:
context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
conversations: List[OnceConversation] = []
if context:
conversations = json.loads(context)
conversations.append(_conversation_to_dict(once_message))
cursor = self.connect.cursor()
if context:
cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
)
else:
cursor.execute(
"INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
[
self.chat_seesion_id,
once_message.chat_mode,
once_message.get_latest_user_message().content,
once_message.user_name,
once_message.sys_code,
json.dumps(conversations, ensure_ascii=False),
],
)
cursor.commit()
self.connect.commit()
def update(self, messages: List[OnceConversation]) -> None:
cursor = self.connect.cursor()
cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(messages, ensure_ascii=False), self.chat_seesion_id],
)
cursor.commit()
self.connect.commit()
def clear(self) -> None:
cursor = self.connect.cursor()
cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit()
self.connect.commit()
def delete(self) -> bool:
cursor = self.connect.cursor()
cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit()
return True
def conv_info(self, conv_uid: str = None) -> None:
cursor = self.connect.cursor()
cursor.execute(
"SELECT * FROM chat_history where conv_uid=? ",
[conv_uid],
)
# Get the field names of the query result
fields = [field[0] for field in cursor.description]
for row in cursor.fetchone():
row_dict = {}
for i, field in enumerate(fields):
row_dict[field] = row[i]
return row_dict
return {}
def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor()
cursor.execute(
"SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
context = cursor.fetchone()
if context:
if context[0]:
return json.loads(context[0])
return None
@staticmethod
def conv_list(
user_name: Optional[str] = None, sys_code: Optional[str] = None
) -> List[Dict]:
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
query = "SELECT * FROM chat_history"
params = []
conditions = []
if user_name:
conditions.append("user_name = ?")
params.append(user_name)
if sys_code:
conditions.append("sys_code = ?")
params.append(sys_code)
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY id DESC LIMIT 20"
cursor.execute(query, params)
fields = [field[0] for field in cursor.description]
data = []
for row in cursor.fetchall():
row_dict = {}
for i, field in enumerate(fields):
row_dict[field] = row[i]
data.append(row_dict)
return data
return []

View File

@@ -1,50 +0,0 @@
import datetime
import json
import os
from pathlib import Path
from typing import List
from dbgpt._private.config import Config
from dbgpt.core.interface.message import (
OnceConversation,
_conversation_from_dict,
_conversations_to_dict,
)
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
CFG = Config()
class FileHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.File.value
def __init__(self, chat_session_id: str):
now = datetime.datetime.now()
date_string = now.strftime("%Y%m%d")
path: str = f"{CFG.message_dir}/{date_string}"
os.makedirs(path, exist_ok=True)
dir_path = Path(path)
self.file_path = Path(dir_path / f"{chat_session_id}.json")
if not self.file_path.exists():
self.file_path.touch()
self.file_path.write_text(json.dumps([]))
def messages(self) -> List[OnceConversation]:
items = json.loads(self.file_path.read_text())
history: List[OnceConversation] = []
for onece in items:
messages = _conversation_from_dict(onece)
history.append(messages)
return history
def append(self, once_message: OnceConversation) -> None:
historys = self.messages()
historys.append(once_message)
self.file_path.write_text(
json.dumps(_conversations_to_dict(historys), ensure_ascii=False, indent=4),
encoding="UTF-8",
)
def clear(self) -> None:
self.file_path.write_text(json.dumps([]))

View File

@@ -1,27 +0,0 @@
from typing import List
from dbgpt._private.config import Config
from dbgpt.core.interface.message import OnceConversation
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory, MemoryStoreType
from dbgpt.util.custom_data_structure import FixedSizeDict
CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.Memory.value
histroies_map = FixedSizeDict(100)
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.histroies_map.update({chat_session_id: []})
def messages(self) -> List[OnceConversation]:
return self.histroies_map.get(self.chat_seesion_id)
def append(self, once_message: OnceConversation) -> None:
self.histroies_map.get(self.chat_seesion_id).append(once_message)
def clear(self) -> None:
self.histroies_map.pop(self.chat_seesion_id)

View File

@@ -1,6 +1,7 @@
from dbgpt.storage.metadata._base_dao import BaseDao
from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory
from dbgpt.storage.metadata.db_manager import (
"""Module for handling metadata storage."""
from dbgpt.storage.metadata._base_dao import BaseDao # noqa: F401
from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory # noqa: F401
from dbgpt.storage.metadata.db_manager import ( # noqa: F401
BaseModel,
DatabaseManager,
Model,

View File

@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
from sqlalchemy.orm.session import Session
@@ -44,6 +44,7 @@ class BaseDao(Generic[T, REQ, RES]):
self,
db_manager: Optional[DatabaseManager] = None,
) -> None:
"""Create a BaseDao instance."""
self._db_manager = db_manager or db
def get_raw_session(self) -> Session:
@@ -52,9 +53,7 @@ class BaseDao(Generic[T, REQ, RES]):
You should commit or roll back the session manually.
We suggest you use :meth:`session` instead.
Example:
.. code-block:: python
user = User(name="Edward Snowden")
@@ -63,13 +62,14 @@ class BaseDao(Generic[T, REQ, RES]):
session.commit()
session.close()
"""
return self._db_manager._session()
return self._db_manager._session() # type: ignore
@contextmanager
def session(self, commit: Optional[bool] = True) -> Session:
def session(self, commit: Optional[bool] = True) -> Iterator[Session]:
"""Provide a transactional scope around a series of operations.
If raise an exception, the session will be roll back automatically, otherwise it will be committed.
If an exception is raised, the session will be rolled back automatically;
otherwise it will be committed.
Example:
.. code-block:: python
@@ -78,7 +78,8 @@ class BaseDao(Generic[T, REQ, RES]):
session.query(User).filter(User.name == "Edward Snowden").first()
Args:
commit (Optional[bool], optional): Whether to commit the session. Defaults to True.
commit (Optional[bool], optional): Whether to commit the session. Defaults
to True.
Returns:
Session: A session object.
@@ -147,7 +148,8 @@ class BaseDao(Generic[T, REQ, RES]):
session.add(entry)
req = self.to_request(entry)
session.commit()
return self.get_one(req)
res = self.get_one(req)
return res # type: ignore
def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES:
"""Update an entity object.
@@ -163,11 +165,14 @@ class BaseDao(Generic[T, REQ, RES]):
entry = query.first()
if entry is None:
raise Exception("Invalid request")
for key, value in update_request.dict().items():
for key, value in update_request.dict().items(): # type: ignore
if value is not None:
setattr(entry, key, value)
session.merge(entry)
return self.get_one(self.to_request(entry))
res = self.get_one(self.to_request(entry))
if not res:
raise Exception("Update failed")
return res
def delete(self, query_request: QUERY_SPEC) -> None:
"""Delete an entity object.
@@ -179,7 +184,8 @@ class BaseDao(Generic[T, REQ, RES]):
result_list = self._get_entity_list(session, query_request)
if len(result_list) != 1:
raise ValueError(
f"Delete request should return one result, but got {len(result_list)}"
f"Delete request should return one result, but got "
f"{len(result_list)}"
)
session.delete(result_list[0])
@@ -241,11 +247,11 @@ class BaseDao(Generic[T, REQ, RES]):
query = self._create_query_object(session, query_request)
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
items = [self.to_response(item) for item in items]
res_items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
return PaginationResult(
items=items,
items=res_items,
total_count=total_count,
total_pages=total_pages,
page=page,
@@ -273,4 +279,4 @@ class BaseDao(Generic[T, REQ, RES]):
if isinstance(value, (list, tuple, dict, set)):
continue
query = query.filter(getattr(model_cls, key) == value)
return query
return query # type: ignore
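The session() helper above is the standard SQLAlchemy transactional-scope pattern: commit on success, roll back on any exception, always close. A minimal standalone sketch of the same pattern (the in-memory SQLite engine and the User model are assumptions for illustration):

    from contextlib import contextmanager

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import declarative_base, sessionmaker

    Base = declarative_base()


    class User(Base):  # hypothetical model, for illustration only
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(100))


    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    SessionFactory = sessionmaker(bind=engine)


    @contextmanager
    def session_scope(commit: bool = True):
        """Commit on success, roll back on error, always close the session."""
        session = SessionFactory()
        try:
            yield session
            if commit:
                session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()


    with session_scope() as session:
        session.add(User(name="Edward Snowden"))

    with session_scope(commit=False) as session:
        print(session.query(User).filter(User.name == "Edward Snowden").first())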

View File

@@ -1,19 +1,26 @@
"""UnifiedDBManagerFactory is a factory class to create a DatabaseManager instance."""
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .db_manager import DatabaseManager
class UnifiedDBManagerFactory(BaseComponent):
"""UnfiedDBManagerFactory class."""
name = ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY
"""The name of the factory."""
def __init__(self, system_app: SystemApp, db_manager: DatabaseManager):
"""Create a UnifiedDBManagerFactory instance."""
super().__init__(system_app)
self._db_manager = db_manager
def init_app(self, system_app: SystemApp):
"""Initialize the factory with the system app."""
pass
def create(self) -> DatabaseManager:
"""Create a DatabaseManager instance."""
if not self._db_manager:
raise RuntimeError("db_manager is not initialized")
if not self._db_manager.is_initialized:
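A rough usage sketch of the factory, assuming the global db manager has already been initialized; the SQLite path and the bare SystemApp() construction are illustrative assumptions:

    from dbgpt.component import SystemApp
    from dbgpt.storage.metadata import db
    from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory

    # Assumption: initialize the global manager with a local SQLite file first.
    db.init_default_db("/tmp/dbgpt_metadata.db")

    system_app = SystemApp()
    factory = UnifiedDBManagerFactory(system_app, db)

    db_manager = factory.create()  # raises RuntimeError if db is not initialized
    with db_manager.session() as session:
        pass  # run queries against the unified metadata database here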

View File

@@ -1,8 +1,9 @@
"""The database manager."""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import ClassVar, Dict, Generic, Optional, Type, TypeVar, Union
from typing import ClassVar, Dict, Generic, Iterator, Optional, Type, TypeVar, Union
from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm
from sqlalchemy.orm import (
@@ -21,19 +22,20 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound="BaseModel")
class _QueryObject:
"""The query object."""
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
return model_cls.query_class(
model_cls, session=model_cls.__db_manager__._session()
)
# class _QueryObject:
# """The query object."""
#
# def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
# return model_cls.query_class(
# model_cls, session=model_cls.__db_manager__._session()
# )
#
class BaseQuery(orm.Query):
def paginate_query(
self, page: Optional[int] = 1, per_page: Optional[int] = 20
) -> PaginationResult:
"""Base query class."""
def paginate_query(self, page: int = 1, per_page: int = 20) -> PaginationResult:
"""Paginate the query.
Example:
@@ -56,10 +58,10 @@ class BaseQuery(orm.Query):
)
print(pagination)
Args:
page (Optional[int], optional): The page number. Defaults to 1.
per_page (Optional[int], optional): The number of items per page. Defaults to 20.
per_page (Optional[int], optional): The number of items per page. Defaults
to 20.
Returns:
PaginationResult: The pagination result.
"""
@@ -100,7 +102,6 @@ class DatabaseManager:
"""The database manager.
Examples:
.. code-block:: python
from urllib.parse import quote_plus as urlquote, quote
@@ -161,6 +162,7 @@ class DatabaseManager:
Query = BaseQuery
def __init__(self):
"""Create a DatabaseManager."""
self._db_url = None
self._base: DeclarativeMeta = self._make_declarative_base(_Model)
self._engine: Optional[Engine] = None
@@ -169,12 +171,12 @@ class DatabaseManager:
@property
def Model(self) -> _Model:
"""Get the declarative base."""
return self._base
return self._base # type: ignore
@property
def metadata(self) -> MetaData:
"""Get the metadata."""
return self.Model.metadata
return self.Model.metadata # type: ignore
@property
def engine(self):
@@ -183,11 +185,11 @@ class DatabaseManager:
@property
def is_initialized(self) -> bool:
"""Whether the database manager is initialized.""" ""
"""Whether the database manager is initialized."""
return self._engine is not None and self._session is not None
@contextmanager
def session(self, commit: Optional[bool] = True) -> Session:
def session(self, commit: Optional[bool] = True) -> Iterator[Session]:
"""Get the session with context manager.
This context manager handles the lifecycle of a SQLAlchemy session.
@@ -199,7 +201,6 @@ class DatabaseManager:
read and write operations.
Examples:
.. code-block:: python
# For write operations (insert, update, delete):
@@ -211,7 +212,8 @@ class DatabaseManager:
# For read-only operations:
with db.session(commit=False) as session:
user = session.query(User).filter_by(name="John Doe").first()
# session.commit() is NOT called, as it's unnecessary for read operations
# session.commit() is NOT called, as it's unnecessary for read
# operations
Args:
commit (Optional[bool], optional): Whether to commit the session.
@@ -237,16 +239,15 @@ class DatabaseManager:
with the ORM object are complete.
c. Re-bind the instance to a new session if further interaction
is required after the session is closed.
"""
if not self.is_initialized:
raise RuntimeError("The database manager is not initialized.")
session = self._session()
session = self._session() # type: ignore
try:
yield session
if commit:
session.commit()
except:
except Exception:
session.rollback()
raise
finally:
@@ -266,10 +267,10 @@ class DatabaseManager:
if not isinstance(model, DeclarativeMeta):
model = declarative_base(cls=model, name="Model")
if not getattr(model, "query_class", None):
model.query_class = self.Query
model.query_class = self.Query # type: ignore
# model.query = _QueryObject()
model.__db_manager__ = self
return model
model.__db_manager__ = self # type: ignore
return model # type: ignore
def init_db(
self,
@@ -284,11 +285,14 @@ class DatabaseManager:
Args:
db_url (Union[str, URL]): The database url.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
override_query_class (Optional[bool], optional): Whether to override the
query class. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults
to None.
"""
if session_options is None:
session_options = {}
@@ -309,8 +313,8 @@ class DatabaseManager:
session_options.setdefault("query_cls", self.Query)
session_factory = sessionmaker(bind=self._engine, **session_options)
# self._session = scoped_session(session_factory)
self._session = session_factory
self._base.metadata.bind = self._engine
self._session = session_factory # type: ignore
self._base.metadata.bind = self._engine # type: ignore
def init_default_db(
self,
@@ -333,24 +337,28 @@ class DatabaseManager:
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
"""
if not engine_args:
engine_args = {}
# Pool class
engine_args["poolclass"] = QueuePool
# The number of connections to keep open inside the connection pool.
engine_args["pool_size"] = 10
# The maximum overflow size of the pool when the number of connections be used in the pool is exceeded(
# pool_size).
engine_args["max_overflow"] = 20
# The number of seconds to wait before giving up on getting a connection from the pool.
engine_args["pool_timeout"] = 30
# Recycle the connection if it has been idle for this many seconds.
engine_args["pool_recycle"] = 3600
# Enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout.
engine_args["pool_pre_ping"] = True
engine_args = {
# Pool class
"poolclass": QueuePool,
# The number of connections to keep open inside the connection pool.
"pool_size": 10,
# The maximum overflow size of the pool when the number of connections
# be used in the pool is exceeded(pool_size).
"max_overflow": 20,
# The number of seconds to wait before giving up on getting a connection
# from the pool.
"pool_timeout": 30,
# Recycle the connection if it has been idle for this many seconds.
"pool_recycle": 3600,
# Enable the connection pool “pre-ping” feature that tests connections
# for liveness upon each checkout.
"pool_pre_ping": True,
}
self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)
def create_all(self):
"""Create all tables."""
self.Model.metadata.create_all(self._engine)
@staticmethod
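For reference, the same pool settings expressed directly against SQLAlchemy's create_engine (the database file name is illustrative):

    from sqlalchemy import create_engine
    from sqlalchemy.pool import QueuePool

    engine = create_engine(
        "sqlite:///dbgpt_metadata.db",
        poolclass=QueuePool,
        pool_size=10,        # connections kept open inside the pool
        max_overflow=20,     # extra connections allowed beyond pool_size
        pool_timeout=30,     # seconds to wait for a free connection
        pool_recycle=3600,   # recycle connections idle for an hour
        pool_pre_ping=True,  # test connections for liveness on checkout
    )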
@@ -364,9 +372,7 @@ class DatabaseManager:
"""Build the database manager from the db_url_or_db.
Examples:
Build from the database url.
.. code-block:: python
from dbgpt.storage.metadata import DatabaseManager
@@ -389,16 +395,19 @@ class DatabaseManager:
print(User.query.filter(User.name == "test").all())
Args:
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the
database manager.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
override_query_class (Optional[bool], optional): Whether to override the
query class. Defaults to False.
Returns:
DatabaseManager: The database manager.
"""
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
if isinstance(db_url_or_db, (str, URL)):
db_manager = DatabaseManager()
db_manager.init_db(
db_url_or_db, engine_args, base, query_class, override_query_class
@@ -408,7 +417,8 @@ class DatabaseManager:
return db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
f"db_url_or_db should be either url or a DatabaseManager, got "
f"{type(db_url_or_db)}"
)
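A minimal end-to-end sketch of build_from with an in-memory SQLite URL; the User model is hypothetical. Since this change comments out the _QueryObject descriptor, the sketch queries through the session() context manager instead of Model.query:

    from sqlalchemy import Column, Integer, String

    from dbgpt.storage.metadata import DatabaseManager

    db_manager = DatabaseManager.build_from("sqlite:///:memory:")


    class User(db_manager.Model):  # type: ignore  # hypothetical model
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(100))


    db_manager.create_all()
    with db_manager.session() as session:
        session.add(User(name="test"))
    with db_manager.session(commit=False) as session:
        print(session.query(User).filter(User.name == "test").all())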
@@ -422,7 +432,6 @@ Examples:
>>> with db.session() as session:
... session.query(...)
...
>>> from dbgpt.storage.metadata import db, Model
>>> from urllib.parse import quote_plus as urlquote, quote
>>> db_name = "dbgpt"
@@ -430,7 +439,10 @@ Examples:
>>> db_port = 3306
>>> user = "root"
>>> password = "123456"
>>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
>>> url = (
... f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}"
... f":{str(db_port)}/{db_name}"
... )
>>> engine_args = {
... "pool_size": 10,
... "max_overflow": 20,
@@ -460,18 +472,21 @@ class BaseCRUDMixin(Generic[T]):
@classmethod
def db(cls) -> DatabaseManager:
"""Get the database manager."""
return cls.__db_manager__
return cls.__db_manager__ # type: ignore
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
"""The base model class that includes CRUD convenience methods."""
__abstract__ = True
"""Whether the model is abstract."""
def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
"""Create a model."""
class CRUDMixin(BaseCRUDMixin[T], Generic[T]): # type: ignore
"""Mixin that adds convenience methods for CRUD."""
_db_manager: DatabaseManager = db_manager
@@ -485,7 +500,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
"""Get the database manager."""
return cls._db_manager
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): # type: ignore
"""Base model class that includes CRUD convenience methods."""
__abstract__ = True
@@ -493,7 +508,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
return _NewModel
Model = create_model(db)
Model: Type = create_model(db)
def initialize_db(
@@ -511,8 +526,10 @@ def initialize_db(
db_name (str): The database name.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the
database. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to
None.
Returns:
DatabaseManager: The database manager.
"""

View File

@@ -1,5 +1,6 @@
"""Database storage implementation using SQLAlchemy."""
from contextlib import contextmanager
from typing import Dict, List, Optional, Type, Union
from typing import Dict, Iterator, List, Optional, Type, Union
from sqlalchemy import URL
from sqlalchemy.orm import DeclarativeMeta, Session
@@ -17,8 +18,8 @@ from .db_manager import BaseModel, BaseQuery, DatabaseManager
def _copy_public_properties(src: BaseModel, dest: BaseModel):
"""Simple copy public properties from src to dest"""
for column in src.__table__.columns:
"""Copy public properties from src to dest."""
for column in src.__table__.columns: # type: ignore
if column.name != "id":
value = getattr(src, column.name)
if value is not None:
@@ -26,6 +27,8 @@ def _copy_public_properties(src: BaseModel, dest: BaseModel):
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
"""Database storage implementation using SQLAlchemy."""
def __init__(
self,
db_url_or_db: Union[str, URL, DatabaseManager],
@@ -36,6 +39,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
):
"""Create a SQLAlchemyStorage instance."""
super().__init__(serializer=serializer, adapter=adapter)
self.db_manager = DatabaseManager.build_from(
db_url_or_db, engine_args, base, query_class
@@ -43,16 +47,19 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
self._model_class = model_class
@contextmanager
def session(self) -> Session:
def session(self) -> Iterator[Session]:
"""Return a session."""
with self.db_manager.session() as session:
yield session
def save(self, data: T) -> None:
"""Save data to the storage."""
with self.session() as session:
model_instance = self.adapter.to_storage_format(data)
session.add(model_instance)
def update(self, data: T) -> None:
"""Update data in the storage."""
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, data.identifier, session=session
@@ -66,6 +73,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
return
def save_or_update(self, data: T) -> None:
"""Save or update data in the storage."""
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, data.identifier, session=session
@@ -79,6 +87,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
self.save(data)
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
"""Load data by identifier from the storage."""
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, resource_id, session=session
@@ -89,6 +98,7 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
return None
def delete(self, resource_id: ResourceIdentifier) -> None:
"""Delete data by identifier from the storage."""
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, resource_id, session=session
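The _copy_public_properties helper above copies every non-id, non-None column value from one ORM row onto another (the kind of merge needed when an incoming object targets a row that already exists). A runnable sketch with a hypothetical Doc model:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()


    class Doc(Base):  # hypothetical model, for illustration only
        __tablename__ = "doc"
        id = Column(Integer, primary_key=True)
        title = Column(String(100))
        body = Column(String(1000))


    def copy_public_properties(src, dest):
        # Same idea as _copy_public_properties: copy every non-id, non-None column.
        for column in src.__table__.columns:
            if column.name != "id":
                value = getattr(src, column.name)
                if value is not None:
                    setattr(dest, column.name, value)


    incoming = Doc(id=1, title="draft", body="hello")
    existing = Doc(id=1)
    copy_public_properties(incoming, existing)
    assert existing.title == "draft" and existing.body == "hello"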

View File

@@ -1,14 +1,21 @@
"""Database information class and database type enumeration."""
import os
from enum import Enum
from typing import Optional
class DbInfo:
"""Database information class."""
def __init__(self, name, is_file_db: bool = False):
"""Create a new instance of DbInfo."""
self.name = name
self.is_file_db = is_file_db
class DBType(Enum):
"""Database type enumeration."""
Mysql = DbInfo("mysql")
OCeanBase = DbInfo("oceanbase")
DuckDb = DbInfo("duckdb", True)
@@ -22,14 +29,24 @@ class DBType(Enum):
Doris = DbInfo("doris")
Hive = DbInfo("hive")
def value(self):
def value(self) -> str:
"""Return the name of the database type."""
return self._value_.name
def is_file_db(self):
def is_file_db(self) -> bool:
"""Return whether the database is a file database."""
return self._value_.is_file_db
@staticmethod
def of_db_type(db_type: str):
def of_db_type(db_type: str) -> Optional["DBType"]:
"""Return the database type of the given name.
Args:
db_type (str): The name of the database type.
Returns:
Optional[DBType]: The database type of the given name.
"""
for item in DBType:
if item.value() == db_type:
return item
@@ -37,7 +54,7 @@ class DBType(Enum):
@staticmethod
def parse_file_db_name_from_path(db_type: str, local_db_path: str):
"""Parse out the database name of the embedded database from the file path"""
"""Parse out the database name of the embedded database from the file path."""
base_name = os.path.basename(local_db_path)
db_name = os.path.splitext(base_name)[0]
if "." in db_name:

View File

@@ -1,3 +1,4 @@
"""Vector Store Module."""
from typing import Any

View File

@@ -1,12 +1,13 @@
"""Vector store base class."""
import logging
import math
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field
from typing import List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import Embeddings
from dbgpt.rag.chunk import Chunk
logger = logging.getLogger(__name__)
@@ -15,6 +16,11 @@ logger = logging.getLogger(__name__)
class VectorStoreConfig(BaseModel):
"""Vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
name: str = Field(
default="dbgpt_collection",
description="The name of vector store, if not set, will use the default name.",
@@ -28,7 +34,7 @@ class VectorStoreConfig(BaseModel):
description="The password of vector store, if not set, will use the default "
"password.",
)
embedding_fn: Optional[Any] = Field(
embedding_fn: Optional[Embeddings] = Field(
default=None,
description="The embedding function of vector store, if not set, will use the "
"default embedding function.",
@@ -47,27 +53,31 @@ class VectorStoreConfig(BaseModel):
class VectorStoreBase(ABC):
"""base class for vector store database"""
"""Vector store base class."""
@abstractmethod
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""load document in vector database.
"""Load document in vector database.
Args:
- chunks: document chunks.
chunks(List[Chunk]): document chunks.
Return:
- ids: chunks ids.
List[str]: chunk ids.
"""
pass
def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
"""load document in vector database with limit.
"""Load document in vector database with specified limit.
Args:
chunks: document chunks.
max_chunks_once_load: Max number of chunks to load at once.
max_threads: Max number of threads to use.
chunks(List[Chunk]): Document chunks.
max_chunks_once_load(int): Max number of chunks to load at once.
max_threads(int): Max number of threads to use.
Return:
List[str]: Chunk ids.
"""
# Group the chunks into chunks of size max_chunks
chunk_groups = [
@@ -96,13 +106,15 @@ class VectorStoreBase(ABC):
return ids
@abstractmethod
def similar_search(self, text, topk) -> List[Chunk]:
"""similar search in vector database.
def similar_search(self, text: str, topk: int) -> List[Chunk]:
"""Similar search in vector database.
Args:
- text: query text
- topk: topk
text(str): The query text.
topk(int): The number of similar documents to return.
Return:
- chunks: chunks.
List[Chunk]: The similar documents.
"""
pass
@@ -110,38 +122,43 @@ class VectorStoreBase(ABC):
def similar_search_with_scores(
self, text, topk, score_threshold: float
) -> List[Chunk]:
"""similar search in vector database with scores.
"""Similar search with scores in vector database.
Args:
- text: query text
- topk: topk
- score_threshold: score_threshold: Optional, a floating point value between 0 to 1
text(str): The query text.
topk(int): The number of similar documents to return.
score_threshold(float): Optional, a floating point value between 0 and 1.
Return:
- chunks: chunks.
List[Chunk]: The similar documents.
"""
pass
@abstractmethod
def vector_name_exists(self) -> bool:
"""is vector store name exist."""
"""Whether vector name exists."""
return False
@abstractmethod
def delete_by_ids(self, ids):
"""delete vector by ids.
def delete_by_ids(self, ids: str):
"""Delete vectors by ids.
Args:
- ids: vector ids
ids(str): The ids of vectors to delete, separated by comma.
"""
@abstractmethod
def delete_vector_name(self, vector_name):
"""delete vector name.
def delete_vector_name(self, vector_name: str):
"""Delete vector by name.
Args:
- vector_name: vector store name
vector_name(str): The name of vector to delete.
"""
pass
def _normalization_vectors(self, vectors):
"""normalization vectors to scale[0,1]"""
"""Return L2-normalization vectors to scale[0,1].
Normalization vectors to scale[0,1].
"""
import numpy as np
norm = np.linalg.norm(vectors)
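A standalone sketch of the batching strategy load_document_with_limit describes: split the chunks into groups of max_chunks_once_load and load the groups from a thread pool. load_group is a stand-in for the store-specific load_document call:

    from concurrent.futures import ThreadPoolExecutor
    from typing import List


    def load_with_limit(
        chunks: List[str], max_chunks_once_load: int = 10, max_threads: int = 1
    ) -> List[str]:
        # Group the chunks into batches of size max_chunks_once_load.
        chunk_groups = [
            chunks[i : i + max_chunks_once_load]
            for i in range(0, len(chunks), max_chunks_once_load)
        ]

        def load_group(group: List[str]) -> List[str]:
            # Stand-in for the real store call; returns one id per chunk.
            return [f"id-{hash(text)}" for text in group]

        ids: List[str] = []
        with ThreadPoolExecutor(max_workers=max_threads) as executor:
            for group_ids in executor.map(load_group, chunk_groups):
                ids.extend(group_ids)
        return ids


    print(len(load_with_limit([f"chunk {i}" for i in range(25)], 10, 2)))  # 25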

View File

@@ -1,14 +1,18 @@
"""Chroma vector store."""
import logging
import os
from typing import Any, List
from chromadb import PersistentClient
from chromadb.config import Settings
from pydantic import Field
from dbgpt._private.pydantic import Field
from dbgpt.configs.model_config import PILOT_PATH
# TODO: Recycle dependency on rag and storage
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from .base import VectorStoreBase, VectorStoreConfig
logger = logging.getLogger(__name__)
@@ -16,20 +20,28 @@ logger = logging.getLogger(__name__)
class ChromaVectorConfig(VectorStoreConfig):
"""Chroma vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
persist_path: str = Field(
default=os.getenv("CHROMA_PERSIST_PATH", None),
description="The password of vector store, if not set, will use the default password.",
description="The password of vector store, if not set, will use the default "
"password.",
)
collection_metadata: dict = Field(
default=None,
description="the index metadata of vector store, if not set, will use the default metadata.",
description="the index metadata of vector store, if not set, will use the "
"default metadata.",
)
class ChromaStore(VectorStoreBase):
"""chroma database"""
"""Chroma vector store."""
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
"""Create a ChromaStore instance."""
from langchain.vectorstores import Chroma
chroma_vector_config = vector_store_config.dict()
@@ -59,6 +71,7 @@ class ChromaStore(VectorStoreBase):
)
def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
"""Search similar documents."""
logger.info("ChromaStore similar search")
lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
return [
@@ -67,14 +80,16 @@ class ChromaStore(VectorStoreBase):
]
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
"""
"""Search similar documents with scores.
Chroma similar_search_with_score.
Return docs and relevance scores in the range [0, 1].
Args:
text(str): query text
topk(int): return docs nums. Defaults to 4.
score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
score_threshold(float): Optional, a floating point value between 0 and 1
used to filter the retrieved docs; 0 is dissimilar, 1 is most similar.
"""
logger.info("ChromaStore similar search with scores")
docs_and_scores = (
@@ -87,8 +102,8 @@ class ChromaStore(VectorStoreBase):
for doc, score in docs_and_scores
]
def vector_name_exists(self):
"""is vector store name exist."""
def vector_name_exists(self) -> bool:
"""Whether vector name exists."""
logger.info(f"Check persist_dir: {self.persist_dir}")
if not os.path.exists(self.persist_dir):
return False
@@ -98,6 +113,7 @@ class ChromaStore(VectorStoreBase):
return len(files) > 0
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document to vector store."""
logger.info("ChromaStore load document")
texts = [chunk.content for chunk in chunks]
metadatas = [chunk.metadata for chunk in chunks]
@@ -105,14 +121,16 @@ class ChromaStore(VectorStoreBase):
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return ids
def delete_vector_name(self, vector_name):
def delete_vector_name(self, vector_name: str):
"""Delete vector name."""
logger.info(f"chroma vector_name:{vector_name} begin delete...")
self.vector_store_client.delete_collection()
self._clean_persist_folder()
return True
def delete_by_ids(self, ids):
logger.info(f"begin delete chroma ids...")
"""Delete vector by ids."""
logger.info(f"begin delete chroma ids: {ids}")
ids = ids.split(",")
if len(ids) > 0:
collection = self.vector_store_client._collection
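A hedged usage sketch of the Chroma store. The module path, the toy FakeEmbeddings class (assuming the usual embed_documents/embed_query interface), and the persist path are illustrative assumptions; in practice embedding_fn should be a real embedding model:

    from typing import List

    from dbgpt.core import Embeddings
    from dbgpt.rag.chunk import Chunk
    from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


    class FakeEmbeddings(Embeddings):
        """Toy embedding function so the sketch is self-contained."""

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return [[float(len(t)), 1.0, 0.0] for t in texts]

        def embed_query(self, text: str) -> List[float]:
            return [float(len(text)), 1.0, 0.0]


    config = ChromaVectorConfig(
        name="dbgpt_example_collection",
        persist_path="/tmp/chroma_persist",  # illustrative path
        embedding_fn=FakeEmbeddings(),
    )
    store = ChromaStore(config)
    ids = store.load_document(
        [Chunk(content="DB-GPT supports Chroma.", metadata={"source": "example"})]
    )
    print(store.vector_name_exists(), ids)
    for chunk in store.similar_search("Which vector stores are supported?", topk=1):
        print(chunk.content)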

View File

@@ -1,19 +1,26 @@
"""Connector for vector store."""
import os
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Type, cast
from dbgpt.rag.chunk import Chunk
from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
connector = {}
connector: Dict[str, Type] = {}
class VectorStoreConnector:
"""The connector for vector store.
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
2.similar_search: similarity search from vector_store
3.similar_search_with_scores: similarity search with similarity score from vector_store
VectorStoreConnector can connect to different vector databases, providing a unified
load-document API and similarity-search APIs.
1. load_document: load knowledge document sources into the vector store (Chroma,
Milvus, Weaviate).
2. similar_search: similarity search from the vector store.
3. similar_search_with_scores: similarity search with similarity scores from the
vector store.
code example:
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -23,9 +30,12 @@ class VectorStoreConnector:
"""
def __init__(
self, vector_store_type: str, vector_store_config: VectorStoreConfig = None
self,
vector_store_type: str,
vector_store_config: Optional[VectorStoreConfig] = None,
) -> None:
"""initialize vector store connector.
"""Create a VectorStoreConnector instance.
Args:
- vector_store_type: vector store type Milvus, Chroma, Weaviate
- vector_store_config: vector store config params.
@@ -34,7 +44,7 @@ class VectorStoreConnector:
self._register()
if self._match(vector_store_type):
self.connector_class = connector.get(vector_store_type)
self.connector_class = connector[vector_store_type]
else:
raise Exception(f"Vector Store Type Not support. {0}", vector_store_type)
@@ -44,11 +54,11 @@ class VectorStoreConnector:
@classmethod
def from_default(
cls,
vector_store_type: str = None,
vector_store_type: Optional[str] = None,
embedding_fn: Optional[Any] = None,
vector_store_config: Optional[VectorStoreConfig] = None,
) -> "VectorStoreConnector":
"""initialize default vector store connector."""
"""Initialize default vector store connector."""
vector_store_type = vector_store_type or os.getenv(
"VECTOR_STORE_TYPE", "Chroma"
)
@@ -56,22 +66,33 @@ class VectorStoreConnector:
vector_store_config = vector_store_config or ChromaVectorConfig()
vector_store_config.embedding_fn = embedding_fn
return cls(vector_store_type, vector_store_config)
real_vector_store_type = cast(str, vector_store_type)
return cls(real_vector_store_type, vector_store_config)
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""load document in vector database.
"""Load document in vector database.
Args:
- chunks: document chunks.
Return chunk ids.
"""
max_chunks_once_load = (
self._vector_store_config.max_chunks_once_load
if self._vector_store_config
else 10
)
max_threads = (
self._vector_store_config.max_threads if self._vector_store_config else 1
)
return self.client.load_document_with_limit(
chunks,
self._vector_store_config.max_chunks_once_load,
self._vector_store_config.max_threads,
max_chunks_once_load,
max_threads,
)
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
"""similar search in vector database.
"""Similar search in vector database.
Args:
- doc: query text
- topk: topk
@@ -83,14 +104,17 @@ class VectorStoreConnector:
def similar_search_with_scores(
self, doc: str, topk: int, score_threshold: float
) -> List[Chunk]:
"""
"""Similar search with scores in vector database.
similar_search_with_score in vector database..
Return docs and relevance scores in the range [0, 1].
Args:
- doc(str): query text
- topk(int): return docs nums. Defaults to 4.
- score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
doc(str): query text
topk(int): return docs nums. Defaults to 4.
score_threshold(float): Optional, a floating point value between 0 and 1
used to filter the retrieved docs; 0 is dissimilar, 1 is most similar.
Return:
- chunks: chunks.
"""
@@ -98,32 +122,33 @@ class VectorStoreConnector:
@property
def vector_store_config(self) -> VectorStoreConfig:
"""vector store config."""
"""Return the vector store config."""
if not self._vector_store_config:
raise ValueError("vector store config not set.")
return self._vector_store_config
def vector_name_exists(self):
"""is vector store name exist."""
"""Whether vector name exists."""
return self.client.vector_name_exists()
def delete_vector_name(self, vector_name):
"""vector store delete
def delete_vector_name(self, vector_name: str):
"""Delete vector name.
Args:
- vector_name: vector store name
"""
return self.client.delete_vector_name(vector_name)
def delete_by_ids(self, ids):
"""vector store delete by ids.
"""Delete vector by ids.
Args:
- ids: vector ids
"""
return self.client.delete_by_ids(ids=ids)
def _match(self, vector_store_type) -> bool:
if connector.get(vector_store_type):
return True
else:
return False
return bool(connector.get(vector_store_type))
def _register(self):
for cls in vector_store.__all__:
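The same flow through the high-level connector; `embeddings` stands for any dbgpt.core.Embeddings implementation (for example the FakeEmbeddings class from the Chroma sketch above), so this is a sketch rather than a copy-paste script:

    from dbgpt.rag.chunk import Chunk
    from dbgpt.storage.vector_store.connector import VectorStoreConnector

    # embeddings: any dbgpt.core.Embeddings implementation (see the Chroma sketch above).
    connector = VectorStoreConnector.from_default("Chroma", embedding_fn=embeddings)

    ids = connector.load_document(
        [Chunk(content="hello dbgpt", metadata={"source": "example"})]
    )
    print(connector.vector_name_exists(), ids)
    for chunk in connector.similar_search("hello", topk=4):
        print(chunk.content)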

View File

@@ -1,13 +1,14 @@
"""Milvus vector store."""
from __future__ import annotations
import json
import logging
import os
from typing import Any, Iterable, List, Optional, Tuple
from typing import Any, Iterable, List, Optional
from pydantic import Field
from dbgpt.rag.chunk import Chunk, Document
from dbgpt._private.pydantic import Field
from dbgpt.core import Embeddings
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from dbgpt.util import string_utils
@@ -17,6 +18,11 @@ logger = logging.getLogger(__name__)
class MilvusVectorConfig(VectorStoreConfig):
"""Milvus vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
uri: str = Field(
default="localhost",
description="The uri of milvus store, if not set, will use the default uri.",
@@ -28,7 +34,8 @@ class MilvusVectorConfig(VectorStoreConfig):
alias: str = Field(
default="default",
description="The alias of milvus store, if not set, will use the default alias.",
description="The alias of milvus store, if not set, will use the default "
"alias.",
)
user: str = Field(
default=None,
@@ -36,35 +43,42 @@ class MilvusVectorConfig(VectorStoreConfig):
)
password: str = Field(
default=None,
description="The password of milvus store, if not set, will use the default password.",
description="The password of milvus store, if not set, will use the default "
"password.",
)
primary_field: str = Field(
default="pk_id",
description="The primary field of milvus store, if not set, will use the default primary field.",
description="The primary field of milvus store, if not set, will use the "
"default primary field.",
)
text_field: str = Field(
default="content",
description="The text field of milvus store, if not set, will use the default text field.",
description="The text field of milvus store, if not set, will use the default "
"text field.",
)
embedding_field: str = Field(
default="vector",
description="The embedding field of milvus store, if not set, will use the default embedding field.",
description="The embedding field of milvus store, if not set, will use the "
"default embedding field.",
)
metadata_field: str = Field(
default="metadata",
description="The metadata field of milvus store, if not set, will use the default metadata field.",
description="The metadata field of milvus store, if not set, will use the "
"default metadata field.",
)
secure: str = Field(
default="",
description="The secure of milvus store, if not set, will use the default secure.",
description="The secure of milvus store, if not set, will use the default "
"secure.",
)
class MilvusStore(VectorStoreBase):
"""Milvus database"""
"""Milvus vector store."""
def __init__(self, vector_store_config: MilvusVectorConfig) -> None:
"""MilvusStore init.
"""Create a MilvusStore instance.
Args:
vector_store_config (MilvusVectorConfig): MilvusStore config.
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
@@ -93,8 +107,11 @@ class MilvusStore(VectorStoreBase):
hex_str = bytes_str.hex()
self.collection_name = hex_str
self.embedding = vector_store_config.embedding_fn
self.fields = []
if not vector_store_config.embedding_fn:
raise ValueError("embedding is required for MilvusStore")
self.embedding: Embeddings = vector_store_config.embedding_fn
self.fields: List = []
self.alias = milvus_vector_config.get("alias") or "default"
# use HNSW by default.
@@ -124,7 +141,8 @@ class MilvusStore(VectorStoreBase):
if (self.username is None) != (self.password is None):
raise ValueError(
"Both username and password must be set to use authentication for Milvus"
"Both username and password must be set to use authentication for "
"Milvus"
)
if self.username:
connect_kwargs["user"] = self.username
@@ -139,7 +157,10 @@ class MilvusStore(VectorStoreBase):
)
def init_schema_and_load(self, vector_name, documents) -> List[str]:
"""Create a Milvus collection, indexes it with HNSW, load document.
"""Create a Milvus collection.
Create a Milvus collection, indexes it with HNSW, load document.
Args:
vector_name (Embeddings): your collection name.
documents (List[str]): Text to insert.
@@ -155,7 +176,7 @@ class MilvusStore(VectorStoreBase):
connections,
utility,
)
from pymilvus.orm.types import infer_dtype_bydata
from pymilvus.orm.types import infer_dtype_bydata # noqa: F401
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
@@ -240,10 +261,10 @@ class MilvusStore(VectorStoreBase):
partition_name: Optional[str] = None,
timeout: Optional[int] = None,
) -> List[str]:
"""add text data into Milvus."""
"""Add text data into Milvus."""
insert_dict: Any = {self.text_field: list(texts)}
try:
import numpy as np
import numpy as np # noqa: F401
text_vector = self.embedding.embed_documents(list(texts))
insert_dict[self.vector_field] = text_vector
@@ -268,7 +289,7 @@ class MilvusStore(VectorStoreBase):
return res.primary_keys
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""load document in vector database."""
"""Load document in vector database."""
batch_size = 500
batched_list = [
chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)
@@ -280,6 +301,7 @@ class MilvusStore(VectorStoreBase):
return doc_ids
def similar_search(self, text, topk) -> List[Chunk]:
"""Perform a search on a query string and return results."""
from pymilvus import Collection, DataType
"""similar_search in vector database."""
@@ -409,12 +431,14 @@ class MilvusStore(VectorStoreBase):
return ret[0], ret
def vector_name_exists(self):
"""Whether vector name exists."""
from pymilvus import utility
"""is vector store name exist."""
return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name):
def delete_vector_name(self, vector_name: str):
"""Delete vector name."""
from pymilvus import utility
"""milvus delete collection name"""
@@ -423,11 +447,12 @@ class MilvusStore(VectorStoreBase):
return True
def delete_by_ids(self, ids):
"""Delete vector by ids."""
from pymilvus import Collection
self.col = Collection(self.collection_name)
"""milvus delete vectors by ids"""
logger.info(f"begin delete milvus ids...")
# milvus delete vectors by ids
logger.info(f"begin delete milvus ids: {ids}")
delete_ids = ids.split(",")
doc_ids = [int(doc_id) for doc_id in delete_ids]
delet_expr = f"{self.primary_field} in {doc_ids}"
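For reference, a tiny standalone sketch of the boolean expression delete_by_ids builds from the comma-separated id string (the ids below are illustrative; pk_id is the default primary_field):

    # delete_by_ids expects a comma-separated string of integer primary keys.
    ids = "448789136096821249,448789136096821250"
    doc_ids = [int(doc_id) for doc_id in ids.split(",")]
    delete_expr = f"pk_id in {doc_ids}"
    print(delete_expr)  # pk_id in [448789136096821249, 448789136096821250]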

View File

@@ -1,9 +1,9 @@
"""Postgres vector store."""
import logging
from typing import Any, List
from pydantic import Field
from dbgpt._private.config import Config
from dbgpt._private.pydantic import Field
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
@@ -15,21 +15,26 @@ CFG = Config()
class PGVectorConfig(VectorStoreConfig):
"""PG vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
connection_string: str = Field(
default=None,
description="the connection string of vector store, if not set, will use the default connection string.",
description="the connection string of vector store, if not set, will use the "
"default connection string.",
)
class PGVectorStore(VectorStoreBase):
"""`Postgres.PGVector` vector store.
"""PG vector store.
To use this, you should have the ``pgvector`` python package installed.
"""
def __init__(self, vector_store_config: PGVectorConfig) -> None:
"""init pgvector storage"""
"""Create a PGVectorStore instance."""
from langchain.vectorstores import PGVector
self.connection_string = vector_store_config.connection_string
@@ -42,23 +47,43 @@ class PGVectorStore(VectorStoreBase):
connection_string=self.connection_string,
)
def similar_search(self, text, topk, **kwargs: Any) -> None:
def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]:
"""Perform similar search in PGVector."""
return self.vector_store_client.similarity_search(text, topk)
def vector_name_exists(self):
def vector_name_exists(self) -> bool:
"""Check if vector name exists."""
try:
self.vector_store_client.create_collection()
return True
except Exception as e:
logger.error("vector_name_exists error", e.message)
logger.error(f"vector_name_exists error, {str(e)}")
return False
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document to PGVector.
Args:
chunks(List[Chunk]): document chunks.
Return:
List[str]: chunk ids.
"""
lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
return self.vector_store_client.from_documents(lc_documents)
def delete_vector_name(self, vector_name):
def delete_vector_name(self, vector_name: str):
"""Delete vector by name.
Args:
vector_name(str): vector name.
"""
return self.vector_store_client.delete_collection()
def delete_by_ids(self, ids):
def delete_by_ids(self, ids: str):
"""Delete vector by ids.
Args:
ids(str): vector ids, separated by comma.
"""
return self.vector_store_client.delete(ids)

View File

@@ -1,14 +1,13 @@
"""Weaviate vector store."""
import logging
import os
from typing import List
from langchain.schema import Document
from pydantic import Field
from dbgpt._private.config import Config
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt._private.pydantic import Field
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from .base import VectorStoreBase, VectorStoreConfig
logger = logging.getLogger(__name__)
CFG = Config()
@@ -17,6 +16,11 @@ CFG = Config()
class WeaviateVectorConfig(VectorStoreConfig):
"""Weaviate vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
weaviate_url: str = Field(
default=os.getenv("WEAVIATE_URL", None),
description="weaviate url address, if not set, will use the default url.",
@@ -28,7 +32,7 @@ class WeaviateVectorConfig(VectorStoreConfig):
class WeaviateStore(VectorStoreBase):
"""Weaviate database"""
"""Weaviate database."""
def __init__(self, vector_store_config: WeaviateVectorConfig) -> None:
"""Initialize with Weaviate client."""
@@ -49,8 +53,8 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client = weaviate.Client(self.weaviate_url)
def similar_search(self, text: str, topk: int) -> None:
"""Perform similar search in Weaviate"""
def similar_search(self, text: str, topk: int) -> List[Chunk]:
"""Perform similar search in Weaviate."""
logger.info("Weaviate similar search")
# nearText = {
# "concepts": [text],
@@ -68,15 +72,16 @@ class WeaviateStore(VectorStoreBase):
docs = []
for r in res:
docs.append(
Document(
page_content=r["page_content"],
Chunk(
content=r["page_content"],
metadata={"metadata": r["metadata"]},
)
)
return docs
def vector_name_exists(self) -> bool:
"""Check if a vector name exists for a given class in Weaviate.
"""Whether the vector name exists in Weaviate.
Returns:
bool: True if the vector name exists, False otherwise.
"""
@@ -85,14 +90,15 @@ class WeaviateStore(VectorStoreBase):
return True
return False
except Exception as e:
logger.error("vector_name_exists error", e.message)
logger.error(f"vector_name_exists error, {str(e)}")
return False
def _default_schema(self) -> None:
"""
Create the schema for Weaviate with a Document class containing metadata and text properties.
"""
"""Create default schema in Weaviate.
Create the schema for Weaviate with a Document class containing metadata and
text properties.
"""
schema = {
"classes": [
{
@@ -137,7 +143,7 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client.schema.create(schema)
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load documents into Weaviate"""
"""Load document to Weaviate."""
logger.info("Weaviate load document")
texts = [doc.content for doc in chunks]
metadatas = [doc.metadata for doc in chunks]
@@ -157,3 +163,5 @@ class WeaviateStore(VectorStoreBase):
data_object=properties, class_name=self.vector_name
)
self.vector_store_client.batch.flush()
# TODO: return ids
return []

View File

@@ -11,4 +11,5 @@ isort==5.10.1
pyupgrade==3.1.0
types-requests
types-beautifulsoup4
types-Markdown
types-Markdown
types-tqdm