refactor: rag storage refactor (#2434)

This commit is contained in:
Aries-ckt
2025-03-17 14:15:21 +08:00
committed by GitHub
parent 34d86d693c
commit fc3fe6b725
52 changed files with 1134 additions and 797 deletions

View File

@@ -10,7 +10,7 @@ from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt_ext.rag.assembler.base import BaseAssembler
from dbgpt_ext.rag.chunk_manager import ChunkParameters
from dbgpt_ext.rag.retriever.bm25 import BM25Retriever
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConfig
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig
class BM25Assembler(BaseAssembler):
@@ -44,7 +44,8 @@ class BM25Assembler(BaseAssembler):
def __init__(
self,
knowledge: Knowledge,
es_config: ElasticsearchVectorConfig,
es_config: ElasticsearchStoreConfig,
name: Optional[str] = "dbgpt",
k1: Optional[float] = 2.0,
b: Optional[float] = 0.75,
chunk_parameters: Optional[ChunkParameters] = None,
@@ -55,7 +56,7 @@ class BM25Assembler(BaseAssembler):
Args:
knowledge: (Knowledge) Knowledge datasource.
es_config: (ElasticsearchVectorConfig) Elasticsearch config.
es_config: (ElasticsearchStoreConfig) Elasticsearch config.
k1 (Optional[float]): Controls non-linear term frequency normalization
(saturation). The default value is 2.0.
b (Optional[float]): Controls to what degree document length normalizes
@@ -70,7 +71,7 @@ class BM25Assembler(BaseAssembler):
self._es_port = es_config.port
self._es_username = es_config.user
self._es_password = es_config.password
self._index_name = es_config.name
self._index_name = name
self._k1 = k1
self._b = b
if self._es_username and self._es_password:
@@ -123,7 +124,8 @@ class BM25Assembler(BaseAssembler):
def load_from_knowledge(
cls,
knowledge: Knowledge,
es_config: ElasticsearchVectorConfig,
es_config: ElasticsearchStoreConfig,
name: Optional[str] = "dbgpt",
k1: Optional[float] = 2.0,
b: Optional[float] = 0.75,
chunk_parameters: Optional[ChunkParameters] = None,
@@ -132,7 +134,8 @@ class BM25Assembler(BaseAssembler):
Args:
knowledge: (Knowledge) Knowledge datasource.
es_config: (ElasticsearchVectorConfig) Elasticsearch config.
es_config: (ElasticsearchStoreConfig) Elasticsearch config.
name: (Optional[str]) BM25 name.
k1: (Optional[float]) BM25 parameter k1.
b: (Optional[float]) BM25 parameter b.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
@@ -144,6 +147,7 @@ class BM25Assembler(BaseAssembler):
return cls(
knowledge=knowledge,
es_config=es_config,
name=name,
k1=k1,
b=b,
chunk_parameters=chunk_parameters,
@@ -153,7 +157,8 @@ class BM25Assembler(BaseAssembler):
async def aload_from_knowledge(
cls,
knowledge: Knowledge,
es_config: ElasticsearchVectorConfig,
es_config: ElasticsearchStoreConfig,
name: Optional[str] = "dbgpt",
k1: Optional[float] = 2.0,
b: Optional[float] = 0.75,
chunk_parameters: Optional[ChunkParameters] = None,
@@ -163,7 +168,7 @@ class BM25Assembler(BaseAssembler):
Args:
knowledge: (Knowledge) Knowledge datasource.
es_config: (ElasticsearchVectorConfig) Elasticsearch config.
es_config: (ElasticsearchStoreConfig) Elasticsearch config.
k1: (Optional[float]) BM25 parameter k1.
b: (Optional[float]) BM25 parameter b.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
@@ -178,6 +183,7 @@ class BM25Assembler(BaseAssembler):
cls,
knowledge,
es_config=es_config,
name=name,
k1=k1,
b=b,
chunk_parameters=chunk_parameters,

View File

@@ -21,7 +21,7 @@ class DBSchemaAssembler(BaseAssembler):
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt_serve.rag.assembler.db_struct import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.connector import VectorStoreBase
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
connection = SQLiteTempConnector.create_temporary_db()

View File

@@ -2,8 +2,9 @@
import logging
import os
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
from dbgpt.core import Embeddings, LLMClient
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
@@ -33,79 +34,86 @@ class GraphRetriever(GraphRetrieverBase):
def __init__(
self,
config,
graph_store_adapter,
llm_client: Optional[LLMClient] = None,
llm_model: Optional[str] = None,
triplet_graph_enabled: Optional[bool] = True,
document_graph_enabled: Optional[bool] = True,
extract_top_k: Optional[int] = 5,
kg_chunk_search_top_k: Optional[int] = 5,
similarity_top_k: Optional[int] = 5,
similarity_score_threshold: Optional[float] = 0.7,
embedding_fn: Optional[Embeddings] = None,
embedding_batch_size: Optional[int] = 20,
enable_text_search: Optional[bool] = False,
text2gql_model_enabled: Optional[bool] = False,
text2gql_model_name: Optional[str] = None,
):
"""Initialize Graph Retriever."""
self._triplet_graph_enabled = config.triplet_graph_enabled or (
self._triplet_graph_enabled = triplet_graph_enabled or (
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
)
self._document_graph_enabled = config.document_graph_enabled or (
self._document_graph_enabled = document_graph_enabled or (
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
)
triplet_topk = int(
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
extract_top_k or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
)
document_topk = int(
config.knowledge_graph_chunk_search_top_size
or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE")
kg_chunk_search_top_k or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE")
)
llm_client = config.llm_client
model_name = config.model_name
llm_client = llm_client
model_name = llm_model
self._enable_similarity_search = (
graph_store_adapter.graph_store.enable_similarity_search
)
self._embedding_batch_size = int(
config.knowledge_graph_embedding_batch_size
or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
embedding_batch_size or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
)
similarity_search_topk = int(
config.similarity_search_topk
or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE")
similarity_top_k or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE")
)
similarity_search_score_threshold = float(
config.extract_score_threshold
similarity_score_threshold
or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
)
self._enable_text_search = config.enable_text_search or (
self._enable_text_search = enable_text_search or (
os.getenv("TEXT_SEARCH_ENABLED", "").lower() == "true"
)
text2gql_model_enabled = config.text2gql_model_enabled or (
text2gql_model_enabled = text2gql_model_enabled or (
os.getenv("TEXT2GQL_MODEL_ENABLED", "").lower() == "true"
)
text2gql_model_name = config.text2gql_model_name or os.getenv(
"TEXT2GQL_MODEL_NAME"
text2gql_model_name = text2gql_model_name or os.getenv("TEXT2GQL_MODEL_NAME")
text2gql_model_enabled = (
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
if "TEXT2GQL_MODEL_ENABLED" in os.environ
else text2gql_model_enabled
)
text2gql_model_name = os.getenv(
"TEXT2GQL_MODEL_NAME",
text2gql_model_name,
)
text2gql_model_enabled = (
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
if "TEXT2GQL_MODEL_ENABLED" in os.environ
else config.text2gql_model_enabled
else text2gql_model_enabled
)
text2gql_model_name = os.getenv(
"TEXT2GQL_MODEL_NAME",
config.text2gql_model_name,
text2gql_model_name,
)
text2gql_model_enabled = (
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
if "TEXT2GQL_MODEL_ENABLED" in os.environ
else config.text2gql_model_enabled
else text2gql_model_enabled
)
text2gql_model_name = os.getenv(
"TEXT2GQL_MODEL_NAME",
config.text2gql_model_name,
)
text2gql_model_enabled = (
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
if "TEXT2GQL_MODEL_ENABLED" in os.environ
else config.text2gql_model_enabled
)
text2gql_model_name = os.getenv(
"TEXT2GQL_MODEL_NAME",
config.text2gql_model_name,
text2gql_model_name,
)
self._keyword_extractor = KeywordExtractor(llm_client, model_name)
self._text_embedder = TextEmbedder(config.embedding_fn)
self._text_embedder = TextEmbedder(embedding_fn)
intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
if text2gql_model_enabled:

View File

@@ -17,19 +17,27 @@ class GraphExtractor(LLMExtractor):
"""GraphExtractor class."""
def __init__(
    self,
    llm_client: LLMClient,
    model_name: str,
    chunk_history: VectorStoreBase,
    index_name: str,
    max_chunks_once_load: Optional[int] = 10,
    max_threads: Optional[int] = 1,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = 0.7,
):
    """Initialize the GraphExtractor.

    Args:
        llm_client: LLM client forwarded to the base extractor.
        model_name: LLM model name forwarded to the base extractor.
        chunk_history: Vector store kept as the chunk-history store.
        index_name: Name stored as this extractor's vector space.
        max_chunks_once_load: Stored on the instance unchanged.
        max_threads: Stored on the instance unchanged.
        top_k: Stored on the instance unchanged.
        score_threshold: Stored on the instance unchanged.
    """
    super().__init__(llm_client, model_name, GRAPH_EXTRACT_PT_CN)
    self._chunk_history = chunk_history

    # All tuning values arrive as explicit arguments; nothing is read back
    # from the vector store's own config any more.
    self._vector_space = index_name
    self._max_chunks_once_load = max_chunks_once_load
    self._max_threads = max_threads
    self._topk = top_k
    self._score_threshold = score_threshold
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
"""Load chunk context."""

View File

@@ -50,11 +50,11 @@ def _import_oceanbase() -> Tuple[Type, Type]:
def _import_elastic() -> Tuple[Type, Type]:
    """Import the Elasticsearch store and its config class on demand.

    The import is done inside the function, so the Elasticsearch module is
    only loaded when this function is called.

    Returns:
        Tuple[Type, Type]: ``(ElasticStore, ElasticsearchStoreConfig)``.
    """
    from dbgpt_ext.storage.vector_store.elastic_store import (
        ElasticsearchStoreConfig,
        ElasticStore,
    )

    return ElasticStore, ElasticsearchStoreConfig
def _import_builtin_knowledge_graph() -> Tuple[Type, Type]:
@@ -115,6 +115,35 @@ def _select_rag_storage(name: str) -> Tuple[Type, Type]:
raise AttributeError(f"Could not find: {name}")
_HAS_SCAN = False
# Result of the one-time scan, so repeated calls return the same value as
# the first call instead of None.
_SCANNED_STORAGE_CONFIGS = None


def scan_storage_configs():
    """Scan storage configs.

    For each of the ``dbgpt_ext`` vector-store, knowledge-graph and
    graph-store packages, a ``ModelScanner`` is asked to scan and register
    the module with ``IndexStoreConfig`` as the base class. The scan runs at
    most once per process; later calls return the cached result.

    Returns:
        Whatever ``scanner.get_registered_items()`` returned for the scan.
    """
    from dbgpt.storage.base import IndexStoreConfig
    from dbgpt.util.module_utils import ModelScanner, ScannerConfig

    global _HAS_SCAN, _SCANNED_STORAGE_CONFIGS

    if _HAS_SCAN:
        return _SCANNED_STORAGE_CONFIGS

    modules = [
        "dbgpt_ext.storage.vector_store",
        "dbgpt_ext.storage.knowledge_graph",
        "dbgpt_ext.storage.graph_store",
    ]

    scanner = ModelScanner[IndexStoreConfig]()
    for module in modules:
        config = ScannerConfig(
            module_path=module,
            base_class=IndexStoreConfig,
        )
        scanner.scan_and_register(config)

    _SCANNED_STORAGE_CONFIGS = scanner.get_registered_items()
    # Flip the flag only after the scan succeeded, so a failed scan is retried.
    _HAS_SCAN = True
    return _SCANNED_STORAGE_CONFIGS
__vector_store__ = [
"Chroma",
"Milvus",

View File

@@ -11,21 +11,19 @@ from dbgpt.storage.full_text.base import FullTextStoreBase
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util import string_utils
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConfig
class ElasticDocumentConfig(ElasticsearchVectorConfig):
"""Elasticsearch document store config."""
k1: Optional[float] = 2.0
b: Optional[float] = 0.75
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig
class ElasticDocumentStore(FullTextStoreBase):
"""Elasticsearch index store."""
def __init__(
self, es_config: ElasticDocumentConfig, executor: Optional[Executor] = None
self,
es_config: ElasticsearchStoreConfig,
name: Optional[str] = "dbgpt",
k1: Optional[float] = 2.0,
b: Optional[float] = 0.75,
executor: Optional[Executor] = None,
):
"""Init elasticsearch index store.
@@ -46,17 +44,17 @@ class ElasticDocumentStore(FullTextStoreBase):
self._es_password = es_config.password or os.getenv(
"ELASTICSEARCH_PASSWORD", "dbgpt"
)
self._index_name = es_config.name.lower()
if string_utils.contains_chinese(es_config.name):
bytes_str = es_config.name.encode("utf-8")
self._index_name = name.lower()
if string_utils.contains_chinese(name):
bytes_str = name.encode("utf-8")
hex_str = bytes_str.hex()
self._index_name = "dbgpt_" + hex_str
# k1 (Optional[float]): Controls non-linear term frequency normalization
# (saturation). The default value is 2.0.
self._k1 = es_config.k1 or 2.0
self._k1 = k1 or 2.0
# b (Optional[float]): Controls to what degree document length normalizes
# tf values. The default value is 0.75.
self._b = es_config.b or 0.75
self._b = b or 0.75
if self._es_username and self._es_password:
self._es_client = Elasticsearch(
hosts=[f"http://{self._es_url}:{self._es_port}"],

View File

@@ -4,9 +4,9 @@ import base64
import json
import logging
import os
from dataclasses import dataclass, field
from typing import List
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import GraphElemType
from dbgpt_ext.datasource.conn_tugraph import TuGraphConnector
@@ -14,66 +14,89 @@ from dbgpt_ext.datasource.conn_tugraph import TuGraphConnector
logger = logging.getLogger(__name__)
@dataclass
class TuGraphStoreConfig(GraphStoreConfig):
    """TuGraph store config.

    Every option is a dataclass field whose human-readable description is
    kept under the ``"description"`` key of the field metadata.
    """

    __type__ = "TuGraph"

    host: str = field(
        default="127.0.0.1",
        metadata={
            "description": "TuGraph host",
        },
    )
    port: int = field(
        default=7687,
        metadata={
            "description": "TuGraph port",
        },
    )
    username: str = field(
        default="admin",
        metadata={
            "description": "login username",
        },
    )
    password: str = field(
        default="73@TuGraph",
        metadata={
            "description": "login password",
        },
    )
    vertex_type: str = field(
        default=GraphElemType.ENTITY.value,
        metadata={
            "description": "The type of vertex, `entity` by default.",
        },
    )
    document_type: str = field(
        default=GraphElemType.DOCUMENT.value,
        metadata={
            "description": "The type of document vertex, `document` by default.",
        },
    )
    chunk_type: str = field(
        default=GraphElemType.CHUNK.value,
        metadata={
            "description": "The type of chunk vertex, `chunk` by default.",
        },
    )
    edge_type: str = field(
        default=GraphElemType.RELATION.value,
        metadata={
            "description": "The type of relation edge, `relation` by default.",
        },
    )
    include_type: str = field(
        default=GraphElemType.INCLUDE.value,
        metadata={
            "description": "The type of include edge, `include` by default.",
        },
    )
    next_type: str = field(
        default=GraphElemType.NEXT.value,
        metadata={
            "description": "The type of next edge, `next` by default.",
        },
    )
    # default_factory gives each instance its own list instead of a shared one.
    plugin_names: List[str] = field(
        default_factory=lambda: ["leiden"],
        metadata={
            "description": "The list of plugin names to be uploaded to the database.",
        },
    )
    enable_summary: bool = field(
        default=True,
        metadata={
            "description": "Enable graph community summary or not.",
        },
    )
    enable_similarity_search: bool = field(
        default=False,
        metadata={
            "description": "Enable the similarity search or not",
        },
    )

View File

@@ -18,18 +18,24 @@ class BuiltinCommunityMetastore(CommunityMetastore):
"""Builtin Community metastore."""
def __init__(
    self,
    vector_store: VectorStoreBase,
    rdb_store: Optional[RDBMSConnector] = None,
    index_name: Optional[str] = None,
    max_chunks_once_load: Optional[int] = 10,
    max_threads: Optional[int] = 1,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = 0.7,
):
    """Initialize Community metastore.

    Args:
        vector_store: Vector store kept on the instance.
        rdb_store: Optional relational connector kept on the instance.
        index_name: Name stored as the metastore's vector space.
        max_chunks_once_load: Stored on the instance unchanged.
        max_threads: Stored on the instance unchanged.
        top_k: Stored on the instance unchanged.
        score_threshold: Stored on the instance unchanged.
    """
    self._vector_store = vector_store
    self._rdb_store = rdb_store

    # Settings are supplied by the caller rather than read from
    # vector_store.get_config().
    self._vector_space = index_name
    self._max_chunks_once_load = max_chunks_once_load
    self._max_threads = max_threads
    self._topk = top_k
    self._score_threshold = score_threshold
def get(self, community_id: str) -> Community:
"""Get community."""

View File

@@ -25,11 +25,23 @@ class CommunityStore:
graph_store_adapter: GraphStoreAdapter,
community_summarizer: CommunitySummarizer,
vector_store: VectorStoreBase,
index_name: Optional[str] = None,
max_chunks_once_load: Optional[int] = 10,
max_threads: Optional[int] = 1,
top_k: Optional[int] = 5,
score_threshold: Optional[float] = 0.7,
):
"""Initialize the CommunityStore class."""
self._graph_store_adapter = graph_store_adapter
self._community_summarizer = community_summarizer
self._meta_store = BuiltinCommunityMetastore(vector_store)
self._meta_store = BuiltinCommunityMetastore(
vector_store=vector_store,
index_name=index_name,
max_chunks_once_load=max_chunks_once_load,
max_threads=max_threads,
top_k=top_k,
score_threshold=score_threshold,
)
async def build_communities(self, batch_size: int = 1):
"""Discover communities."""

View File

@@ -5,9 +5,9 @@ import os
import uuid
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk, LLMClient
from dbgpt.core import Chunk, Embeddings, LLMClient
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.graph_store.base import GraphStoreConfig
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilters
@@ -17,13 +17,13 @@ from dbgpt_ext.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt_ext.rag.transformer.graph_embedder import GraphEmbedder
from dbgpt_ext.rag.transformer.graph_extractor import GraphExtractor
from dbgpt_ext.rag.transformer.text_embedder import TextEmbedder
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
from dbgpt_ext.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
GRAPH_PARAMETERS,
BuiltinKnowledgeGraph,
BuiltinKnowledgeGraphConfig,
)
from dbgpt_ext.storage.vector_store.factory import VectorStoreFactory
logger = logging.getLogger(__name__)
@@ -139,87 +139,6 @@ logger = logging.getLogger(__name__)
),
],
)
class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
"""Community summary knowledge graph config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
vector_store_type: str = Field(
default="Chroma",
description="The type of vector store.",
)
user: Optional[str] = Field(
default=None,
description="The user of vector store, if not set, will use the default user.",
)
password: Optional[str] = Field(
default=None,
description=(
"The password of vector store, if not set, will use the default password."
),
)
extract_topk: int = Field(
default=5,
description="Topk of knowledge graph extract",
)
extract_score_threshold: float = Field(
default=0.3,
description="Recall score of knowledge graph extract",
)
community_topk: int = Field(
default=50,
description="Topk of community search in knowledge graph",
)
community_score_threshold: float = Field(
default=0.3,
description="Recall score of community search in knowledge graph",
)
triplet_graph_enabled: bool = Field(
default=True,
description="Enable the graph search for triplets",
)
document_graph_enabled: bool = Field(
default=True,
description="Enable the graph search for documents and chunks",
)
knowledge_graph_chunk_search_top_size: int = Field(
default=5,
description="Top size of knowledge graph chunk search",
)
knowledge_graph_extraction_batch_size: int = Field(
default=20,
description="Batch size of triplets extraction from the text",
)
community_summary_batch_size: int = Field(
default=20,
description="Batch size of parallel community building process",
)
knowledge_graph_embedding_batch_size: int = Field(
default=20,
description="Batch size of triplets embedding from the text",
)
similarity_search_topk: int = Field(
default=5,
description="Topk of similarity search",
)
similarity_search_score_threshold: float = Field(
default=0.7,
description="Recall score of similarity search",
)
enable_text_search: bool = Field(
default=False,
description="Enable text2gql search or not.",
)
text2gql_model_enabled: bool = Field(
default=False,
description="Enable fine-tuned text2gql model for text2gql translation.",
)
text2gql_model_name: str = Field(
default=None,
description="LLM Model for text2gql translation.",
)
@register_resource(
_("Community Summary Knowledge Graph"),
"community_summary_knowledge_graph",
@@ -239,104 +158,137 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
"""Community summary knowledge graph class."""
def __init__(
    self,
    config: GraphStoreConfig,
    name: Optional[str] = "dbgpt",
    llm_client: Optional[LLMClient] = None,
    llm_model: Optional[str] = None,
    kg_extract_top_k: Optional[int] = 5,
    kg_extract_score_threshold: Optional[float] = 0.3,
    kg_community_top_k: Optional[int] = 50,
    kg_community_score_threshold: Optional[float] = 0.3,
    kg_triplet_graph_enabled: Optional[bool] = True,
    kg_document_graph_enabled: Optional[bool] = True,
    kg_chunk_search_top_k: Optional[int] = 5,
    kg_extraction_batch_size: Optional[int] = 3,
    kg_community_summary_batch_size: Optional[int] = 20,
    kg_embedding_batch_size: Optional[int] = 20,
    kg_similarity_top_k: Optional[int] = 5,
    kg_similarity_score_threshold: Optional[float] = 0.7,
    kg_enable_text_search: Optional[bool] = False,
    kg_text2gql_model_enabled: Optional[bool] = False,
    kg_text2gql_model_name: Optional[str] = None,
    embedding_fn: Optional[Embeddings] = None,
    vector_store_config: Optional["VectorStoreConfig"] = None,
    kg_max_chunks_once_load: Optional[int] = 10,
    kg_max_threads: Optional[int] = 1,
):
    """Initialize community summary knowledge graph class.

    Numeric and boolean ``kg_*`` options fall back to an environment
    variable when the argument is falsy (the variable name is shown next to
    each option below).

    Args:
        config: Graph store config, forwarded to the base class.
        name: Graph name; also the prefix of the two vector store names
            (``<name>_CHUNK_HISTORY`` and ``<name>_COMMUNITY_SUMMARY``).
        llm_client: LLM client, forwarded to the base class and retriever.
        llm_model: LLM model name, forwarded likewise.
        kg_extract_top_k: Top-k for graph extraction.
        kg_extract_score_threshold: Score threshold for graph extraction.
        kg_community_top_k: Top-k for community search.
        kg_community_score_threshold: Score threshold for community search.
        kg_triplet_graph_enabled: Enable the triplet graph.
        kg_document_graph_enabled: Enable the document graph.
        kg_chunk_search_top_k: Forwarded to the graph retriever.
        kg_extraction_batch_size: Triplet extraction batch size.
        kg_community_summary_batch_size: Community summary batch size.
        kg_embedding_batch_size: Triplet embedding batch size.
        kg_similarity_top_k: Forwarded to the graph retriever.
        kg_similarity_score_threshold: Forwarded to the graph retriever.
        kg_enable_text_search: Forwarded to the graph retriever.
        kg_text2gql_model_enabled: Forwarded to the graph retriever.
        kg_text2gql_model_name: Forwarded to the graph retriever.
        embedding_fn: Embedding function used by the embedders and by the
            vector stores created here.
        vector_store_config: Config whose ``create_store`` builds the chunk
            history and community summary vector stores.
        kg_max_chunks_once_load: Forwarded to the extractor and the
            community store.
        kg_max_threads: Forwarded to the extractor and the community store.

    Raises:
        ValueError: If ``vector_store_config`` is not provided.
    """
    super().__init__(
        config=config, name=name, llm_client=llm_client, llm_model=llm_model
    )
    self._config = config

    self._vector_store_type = config.get_type_value() or os.getenv(
        "VECTOR_STORE_TYPE"
    )
    self._extract_topk = int(
        kg_extract_top_k or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
    )
    self._extract_score_threshold = float(
        kg_extract_score_threshold
        or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
    )
    self._community_topk = int(
        kg_community_top_k or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE")
    )
    self._community_score_threshold = float(
        kg_community_score_threshold
        or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE")
    )
    self._document_graph_enabled = kg_document_graph_enabled or (
        os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
    )
    self._triplet_graph_enabled = kg_triplet_graph_enabled or (
        os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
    )
    self._triplet_extraction_batch_size = int(
        kg_extraction_batch_size
        or os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE")
    )
    self._triplet_embedding_batch_size = int(
        kg_embedding_batch_size or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
    )
    self._community_summary_batch_size = int(
        kg_community_summary_batch_size or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE")
    )

    self._embedding_fn = embedding_fn
    self._vector_store_config = vector_store_config

    # The parameter defaults to None but both stores below are built from
    # it; fail with a clear message instead of an AttributeError.
    if vector_store_config is None:
        raise ValueError("No vector store config provided.")

    self._graph_extractor = GraphExtractor(
        self._llm_client,
        self._model_name,
        vector_store_config.create_store(
            name=name + "_CHUNK_HISTORY", embedding_fn=embedding_fn
        ),
        index_name=name,
        max_chunks_once_load=kg_max_chunks_once_load,
        max_threads=kg_max_threads,
        # Use the resolved (argument-or-environment) extraction settings.
        top_k=self._extract_topk,
        score_threshold=self._extract_score_threshold,
    )

    self._graph_embedder = GraphEmbedder(embedding_fn)
    self._text_embedder = TextEmbedder(embedding_fn)

    self._community_store = CommunityStore(
        self._graph_store_adapter,
        CommunitySummarizer(self._llm_client, self._model_name),
        vector_store_config.create_store(
            name=name + "_COMMUNITY_SUMMARY", embedding_fn=embedding_fn
        ),
        index_name=name,
        max_chunks_once_load=kg_max_chunks_once_load,
        max_threads=kg_max_threads,
        # Community search has its own top-k and threshold; it must not
        # reuse the extraction threshold.
        top_k=self._community_topk,
        score_threshold=self._community_score_threshold,
    )

    self._graph_retriever = GraphRetriever(
        self._graph_store_adapter,
        llm_client=llm_client,
        llm_model=llm_model,
        triplet_graph_enabled=kg_triplet_graph_enabled,
        document_graph_enabled=kg_document_graph_enabled,
        extract_top_k=kg_extract_top_k,
        kg_chunk_search_top_k=kg_chunk_search_top_k,
        similarity_top_k=kg_similarity_top_k,
        similarity_score_threshold=kg_similarity_score_threshold,
        embedding_fn=embedding_fn,
        embedding_batch_size=kg_embedding_batch_size,
        enable_text_search=kg_enable_text_search,
        text2gql_model_enabled=kg_text2gql_model_enabled,
        text2gql_model_name=kg_text2gql_model_name,
    )
def get_config(self) -> TuGraphStoreConfig:
    """Get the knowledge graph config.

    Returns:
        The config object stored on this instance as ``_config``.
    """
    return self._config
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Extract and persist graph from the document file."""
if not self.vector_name_exists():
self._graph_store_adapter.create_graph(self.get_config().name)
self._graph_store_adapter.create_graph(self._graph_name)
await self._aload_document_graph(chunks)
await self._aload_triplet_graph(chunks)
await self._community_store.build_communities(
@@ -447,7 +399,8 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
chunk_name=doc_name,
)
# chunk.metadata = {"Header0": "title", "Header1": "title", ..., "source": "source_path"} # noqa: E501
# chunk.metadata = {"Header0": "title",
# "Header1": "title", ..., "source": "source_path"} # noqa: E501
for chunk_index, chunk in enumerate(chunks):
parent = None
directory_keys = list(chunk.metadata.keys())[
@@ -542,7 +495,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
self._graph_embedder.truncate()
logger.info("Truncate text embedder")
self._text_embedder.truncate()
return [self._config.name]
return [self._graph_name]
def delete_vector_name(self, index_name: str):
"""Delete knowledge graph."""

View File

@@ -3,9 +3,9 @@
import asyncio
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk, Embeddings, LLMClient
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
@@ -16,6 +16,7 @@ from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _
from dbgpt_ext.rag.transformer.triplet_extractor import TripletExtractor
from dbgpt_ext.storage.graph_store.factory import GraphStoreFactory
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
from dbgpt_ext.storage.knowledge_graph.community.base import GraphStoreAdapter
from dbgpt_ext.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
@@ -98,16 +99,19 @@ GRAPH_PARAMETERS = [
),
],
)
@dataclass
class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig):
    """Builtin knowledge graph config."""

    __type__ = "TuGraph"

    llm_model: Optional[str] = field(
        default=None, metadata={"description": "llm model name."}
    )

    graph_type: Optional[str] = field(
        default="TuGraph", metadata={"description": "graph store type."}
    )
@register_resource(
@@ -129,33 +133,38 @@ class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig):
class BuiltinKnowledgeGraph(KnowledgeGraphBase):
"""Builtin knowledge graph class."""
def __init__(
    self,
    config: Optional[GraphStoreConfig] = None,
    name: Optional[str] = "dbgpt",
    llm_client: Optional[LLMClient] = None,
    llm_model: Optional[str] = None,
):
    """Create builtin knowledge graph instance.

    Args:
        config: Graph store config; kept on the instance and used to build
            the graph store.
        name: Graph name kept on the instance as ``_graph_name``.
        llm_client: LLM client shared by the triplet and keyword extractors.
        llm_model: LLM model name shared by both extractors.

    Raises:
        ValueError: If ``llm_client`` is missing.
    """
    super().__init__()
    self._config = config
    self._graph_name = name

    self._llm_client = llm_client
    if not self._llm_client:
        raise ValueError("No llm client provided.")

    self._model_name = llm_model
    self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name)
    self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name)

    # The adapter wraps the graph store, so the store must be built first.
    self._graph_store: GraphStoreBase = self.__init_graph_store(config)
    self._graph_store_adapter: GraphStoreAdapter = self.__init_graph_store_adapter()
def __init_graph_store(self, config: GraphStoreConfig) -> GraphStoreBase:
    """Build the graph store through ``GraphStoreFactory``.

    The store type comes from ``config.get_type_value()`` and falls back to
    the ``GRAPH_STORE_TYPE`` environment variable when that is falsy.
    """

    def configure(cfg: GraphStoreConfig):
        # The graph name comes from the constructor argument, not the config.
        cfg.name = self._graph_name

    graph_store_type = config.get_type_value() or os.getenv("GRAPH_STORE_TYPE")
    return GraphStoreFactory.create(graph_store_type, configure, config.to_dict())
def __init_graph_store_adapter(self):
    """Build the graph store adapter from ``self._graph_store``."""
    adapter = GraphStoreAdapterFactory.create(self._graph_store)
    return adapter
def get_config(self) -> TuGraphStoreConfig:
    """Get the knowledge graph config.

    Returns:
        The config object stored on the instance as ``self._config``.
    """
    return self._config
@@ -171,7 +180,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
# wait async tasks completed
if not self.vector_name_exists():
self._graph_store_adapter.create_graph(self.get_config().name)
self._graph_store_adapter.create_graph(self._graph_name)
tasks = [process_chunk(chunk) for chunk in chunks]
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -188,7 +197,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
List[str]: chunk ids.
"""
if not self.vector_name_exists():
self._graph_store_adapter.create_graph(self.get_config().name)
self._graph_store_adapter.create_graph(self._graph_name)
for chunk in chunks:
triplets = await self._triplet_extractor.extract(chunk.content)
for triplet in triplets:
@@ -257,7 +266,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def truncate(self) -> List[str]:
"""Truncate knowledge graph."""
logger.info(f"Truncate graph {self._config.name}")
logger.info(f"Truncate graph {self._graph_name}")
self._graph_store_adapter.truncate()
logger.info("Truncate keyword extractor")
@@ -266,7 +275,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
logger.info("Truncate triplet extractor")
self._triplet_extractor.truncate()
return [self._config.name]
return [self._graph_name]
def delete_vector_name(self, index_name: str):
"""Delete vector name."""
@@ -286,4 +295,4 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def vector_name_exists(self) -> bool:
    """Whether name exists.

    Returns:
        bool: Result of the graph store's ``is_exist`` check for
        ``self._graph_name``.
    """
    return self._graph_store_adapter.graph_store.is_exist(self._graph_name)

View File

@@ -2,11 +2,11 @@
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.configs.model_config import PILOT_PATH, resolve_root_path
from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
_COMMON_PARAMETERS,
@@ -18,8 +18,6 @@ from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__)
CHROMA_COLLECTION_NAME = "langchain"
@register_resource(
_("Chroma Config"),
@@ -38,21 +36,29 @@ CHROMA_COLLECTION_NAME = "langchain"
),
],
)
@dataclass
class ChromaVectorConfig(VectorStoreConfig):
"""Chroma vector store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
__type__ = "Chroma"
persist_path: Optional[str] = Field(
persist_path: Optional[str] = field(
default=os.getenv("CHROMA_PERSIST_PATH", None),
description="the persist path of vector store.",
metadata={
"help": _("The persist path of vector store."),
},
)
collection_metadata: Optional[dict] = Field(
collection_metadata: Optional[dict] = field(
default=None,
description="the index metadata of vector store, if not set, will use the "
"default metadata.",
metadata={
"help": _("The metadata of collection."),
},
)
def create_store(self, **kwargs) -> "ChromaStore":
    """Build a ``ChromaStore`` bound to this config.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``ChromaStore`` constructor.

    Returns:
        ChromaStore: Store constructed with ``vector_store_config=self``.
    """
    store = ChromaStore(vector_store_config=self, **kwargs)
    return store
@register_resource(
_("Chroma Vector Store"),
@@ -73,11 +79,22 @@ class ChromaVectorConfig(VectorStoreConfig):
class ChromaStore(VectorStoreBase):
"""Chroma vector store."""
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
def __init__(
self,
vector_store_config: ChromaVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
chroma_client: Optional["PersistentClient"] = None, # type: ignore # noqa
collection_metadata: Optional[dict] = None,
) -> None:
"""Create a ChromaStore instance.
Args:
vector_store_config(ChromaVectorConfig): vector store config.
name(str): collection name.
embedding_fn(Embeddings): embedding function.
chroma_client(PersistentClient): chroma client.
collection_metadata(dict): collection metadata.
"""
super().__init__()
self._vector_store_config = vector_store_config
@@ -85,28 +102,27 @@ class ChromaStore(VectorStoreBase):
from chromadb import PersistentClient, Settings
except ImportError:
raise ImportError("Please install chroma package first.")
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_vector_config = vector_store_config.to_dict()
chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data")
)
self.persist_dir = os.path.join(
resolve_root_path(chroma_path), vector_store_config.name + ".vectordb"
)
self.embeddings = vector_store_config.embedding_fn
self.persist_dir = os.path.join(resolve_root_path(chroma_path) + "/chromadb")
self.embeddings = embedding_fn
if not self.embeddings:
raise ValueError("Embeddings is None")
chroma_settings = Settings(
# chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
persist_directory=self.persist_dir,
anonymized_telemetry=False,
)
self._chroma_client = PersistentClient(
path=self.persist_dir, settings=chroma_settings
)
collection_metadata = chroma_vector_config.get("collection_metadata") or {
"hnsw:space": "cosine"
}
self._chroma_client = chroma_client
if not self._chroma_client:
self._chroma_client = PersistentClient(
path=self.persist_dir, settings=chroma_settings
)
collection_metadata = collection_metadata or {"hnsw:space": "cosine"}
self._collection = self._chroma_client.get_or_create_collection(
name=CHROMA_COLLECTION_NAME,
name=name,
embedding_function=None,
metadata=collection_metadata,
)

View File

@@ -4,9 +4,9 @@ from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional
from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
@@ -73,47 +73,63 @@ logger = logging.getLogger(__name__)
],
description=_("Elasticsearch vector config."),
)
class ElasticsearchVectorConfig(VectorStoreConfig):
@dataclass
class ElasticsearchStoreConfig(VectorStoreConfig):
"""Elasticsearch vector store config."""
class Config:
"""Config for BaseModel."""
__type__ = "ElasticSearch"
arbitrary_types_allowed = True
uri: str = Field(
uri: str = field(
default="localhost",
description="The uri of elasticsearch store, if not set, will use the default "
"uri.",
metadata={
"description": "The uri of elasticsearch store, if not set, "
"will use the default uri."
},
)
port: str = Field(
port: str = field(
default="9200",
description="The port of elasticsearch store, if not set, will use the default "
"port.",
metadata={
"description": "The port of elasticsearch store, if not set, will use the "
"default port."
},
)
alias: str = Field(
alias: str = field(
default="default",
description="The alias of elasticsearch store, if not set, will use the "
"default "
"alias.",
metadata={
"description": "The alias of elasticsearch store, if not set, will use the "
"default alias."
},
)
index_name: str = Field(
index_name: str = field(
default="index_name_test",
description="The index name of elasticsearch store, if not set, will use the "
"default index name.",
metadata={
"description": "The index name of elasticsearch store, if not set, will"
" use "
"the default index name."
},
)
metadata_field: str = Field(
metadata_field: str = field(
default="metadata",
description="The metadata field of elasticsearch store, if not set, will use "
"the default metadata field.",
metadata={
"description": "The metadata field of elasticsearch store, "
"if not set, will "
"use the default metadata field."
},
)
secure: str = Field(
secure: str = field(
default="",
description="The secure of elasticsearch store, if not set, will use the "
"default secure.",
metadata={
"description": "The secure of elasticsearch store, i"
"f not set, will use the "
"default secure."
},
)
def create_store(self, **kwargs) -> "ElasticStore":
    """Build an ``ElasticStore`` bound to this config.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``ElasticStore`` constructor.

    Returns:
        ElasticStore: Store constructed with ``vector_store_config=self``.
    """
    store = ElasticStore(vector_store_config=self, **kwargs)
    return store
@register_resource(
_("Elastic Vector Store"),
@@ -124,7 +140,7 @@ class ElasticsearchVectorConfig(VectorStoreConfig):
Parameter.build_from(
_("Elastic Config"),
"vector_store_config",
ElasticsearchVectorConfig,
ElasticsearchStoreConfig,
description=_("the elastic config of vector store."),
optional=True,
default=None,
@@ -134,11 +150,16 @@ class ElasticsearchVectorConfig(VectorStoreConfig):
class ElasticStore(VectorStoreBase):
"""Elasticsearch vector store."""
def __init__(self, vector_store_config: ElasticsearchVectorConfig) -> None:
def __init__(
self,
vector_store_config: ElasticsearchStoreConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
) -> None:
"""Create a ElasticsearchStore instance.
Args:
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
vector_store_config (ElasticsearchStoreConfig): ElasticsearchStore config.
"""
super().__init__()
self._vector_store_config = vector_store_config
@@ -158,21 +179,19 @@ class ElasticStore(VectorStoreBase):
"ELASTICSEARCH_PASSWORD"
) or elasticsearch_vector_config.get("password")
self.collection_name = (
elasticsearch_vector_config.get("name") or vector_store_config.name
)
self.collection_name = name
# name to hex
if string_utils.contains_chinese(self.collection_name):
bytes_str = self.collection_name.encode("utf-8")
hex_str = bytes_str.hex()
self.collection_name = hex_str
if vector_store_config.embedding_fn is None:
if embedding_fn is None:
# Perform runtime checks on self.embedding to
# ensure it has been correctly set and loaded
raise ValueError("embedding_fn is required for ElasticSearchStore")
# to lower case
self.index_name = self.collection_name.lower()
self.embedding: Embeddings = vector_store_config.embedding_fn
self.embedding: Embeddings = embedding_fn
self.fields: List = []
if (self.username is None) != (self.password is None):
@@ -251,7 +270,7 @@ class ElasticStore(VectorStoreBase):
except Exception as e:
logger.error(f"ElasticSearch connection failed: {e}")
def get_config(self) -> ElasticsearchStoreConfig:
    """Get the vector store config.

    Returns:
        The config object stored on the instance as
        ``self._vector_store_config``.
    """
    return self._vector_store_config

View File

@@ -1,46 +0,0 @@
"""Vector store factory."""
import logging
from typing import Tuple, Type
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from dbgpt_ext.storage import __vector_store__ as vector_store_list
from dbgpt_ext.storage import _select_rag_storage
logger = logging.getLogger(__name__)
class VectorStoreFactory:
    """Factory for vector store."""

    @staticmethod
    def create(
        vector_store_type: str, vector_space_name: str, vector_store_configure=None
    ) -> VectorStoreBase:
        """Create a VectorStore instance.

        Args:
            vector_store_type: Vector store type (Chroma, Milvus, etc.),
                compared case-insensitively with the names in
                ``vector_store_list``.
            vector_space_name: Passed as the first argument to
                ``vector_store_configure``.
            vector_store_configure: Optional callable invoked as
                ``vector_store_configure(vector_space_name, config)`` on the
                newly created config before the store is constructed.

        Returns:
            The store instance built from the config object.

        Raises:
            Exception: If ``vector_store_type`` is not supported. Any error
                raised while building the config or the store is logged and
                re-raised unchanged.
        """
        store_cls, cfg_cls = VectorStoreFactory.__find_type(vector_store_type)
        try:
            config = cfg_cls()
            if vector_store_configure:
                vector_store_configure(vector_space_name, config)
            return store_cls(config)
        except Exception as e:
            # Keep the traceback in the log and re-raise the same exception.
            logger.exception("create vector store failed: %s", e)
            raise

    @staticmethod
    def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
        """Resolve the (store class, config class) pair for a store type name.

        Raises:
            Exception: If no name in ``vector_store_list`` matches, or the
                matched classes are not ``VectorStoreBase`` /
                ``VectorStoreConfig`` subclasses.
        """
        wanted = vector_store_type.lower()
        for t in vector_store_list:
            if t.lower() != wanted:
                continue
            store_cls, cfg_cls = _select_rag_storage(t)
            if issubclass(store_cls, VectorStoreBase) and issubclass(
                cfg_cls, VectorStoreConfig
            ):
                return store_cls, cfg_cls
        raise Exception(f"Vector store {vector_store_type} not supported")

View File

@@ -5,9 +5,9 @@ from __future__ import annotations
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Optional
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
@@ -94,51 +94,82 @@ logger = logging.getLogger(__name__)
],
description=_("Milvus vector config."),
)
@dataclass
class MilvusVectorConfig(VectorStoreConfig):
"""Milvus vector store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
__type__ = "Milvus"
uri: Optional[str] = Field(
uri: str = field(
default=None,
description="The uri of milvus store, if not set, will use the default uri.",
metadata={
"help": _("The uri of milvus store, if not set, will use the default uri.")
},
)
port: str = Field(
port: str = field(
default="19530",
description="The port of milvus store, if not set, will use the default port.",
metadata={
"help": _(
"The port of milvus store, if not set, will use the default port."
)
},
)
alias: str = Field(
alias: str = field(
default="default",
description="The alias of milvus store, if not set, will use the default "
"alias.",
metadata={
"help": _(
"The alias of milvus store, if not set, will use the default alias."
)
},
)
primary_field: str = Field(
primary_field: str = field(
default="pk_id",
description="The primary field of milvus store, if not set, will use the "
"default primary field.",
metadata={
"help": _(
"The primary field of milvus store, i"
"f not set, will use the default primary field."
)
},
)
text_field: str = Field(
text_field: str = field(
default="content",
description="The text field of milvus store, if not set, will use the default "
"text field.",
metadata={
"help": _(
"The text field of milvus store, if not set, will use the "
"default text field."
)
},
)
embedding_field: str = Field(
embedding_field: str = field(
default="vector",
description="The embedding field of milvus store, if not set, will use the "
"default embedding field.",
metadata={
"help": _(
"The embedding field of milvus store, if not set, will use the "
"default embedding field."
)
},
)
metadata_field: str = Field(
metadata_field: str = field(
default="metadata",
description="The metadata field of milvus store, if not set, will use the "
"default metadata field.",
metadata={
"help": _(
"The metadata field of milvus store, if not set, will use the "
"default metadata field."
)
},
)
secure: str = Field(
secure: str = field(
default="",
description="The secure of milvus store, if not set, will use the default "
"secure.",
metadata={
"help": _("The secure of milvus store, if not set, will use the default ")
},
)
def create_store(self, **kwargs) -> "MilvusStore":
    """Build a ``MilvusStore`` bound to this config.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``MilvusStore`` constructor.

    Returns:
        MilvusStore: Store constructed with ``vector_store_config=self``.
    """
    store = MilvusStore(vector_store_config=self, **kwargs)
    return store
@register_resource(
_("Milvus Vector Store"),
@@ -159,7 +190,12 @@ class MilvusVectorConfig(VectorStoreConfig):
class MilvusStore(VectorStoreBase):
"""Milvus vector store."""
def __init__(self, vector_store_config: MilvusVectorConfig) -> None:
def __init__(
self,
vector_store_config: MilvusVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
) -> None:
"""Create a MilvusStore instance.
Args:
@@ -190,18 +226,16 @@ class MilvusStore(VectorStoreBase):
)
self.secure = milvus_vector_config.get("secure") or os.getenv("MILVUS_SECURE")
self.collection_name = (
milvus_vector_config.get("name") or vector_store_config.name
)
self.collection_name = name
if string_utils.contains_chinese(self.collection_name):
bytes_str = self.collection_name.encode("utf-8")
hex_str = bytes_str.hex()
self.collection_name = hex_str
if vector_store_config.embedding_fn is None:
if embedding_fn is None:
# Perform runtime checks on self.embedding to
# ensure it has been correctly set and loaded
raise ValueError("embedding_fn is required for MilvusStore")
self.embedding: Embeddings = vector_store_config.embedding_fn
self.embedding: Embeddings = embedding_fn
self.fields: List = []
self.alias = milvus_vector_config.get("alias") or "default"

View File

@@ -5,14 +5,14 @@ import logging
import math
import os
import uuid
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple
import numpy as np
from pydantic import Field
from sqlalchemy import JSON, Column, String, Table, func, text
from sqlalchemy.dialects.mysql import LONGTEXT
from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
_COMMON_PARAMETERS,
@@ -122,35 +122,48 @@ def _normalize(vector: List[float]) -> List[float]:
],
description="OceanBase vector store config.",
)
@dataclass
class OceanBaseConfig(VectorStoreConfig):
"""OceanBase vector store config."""
class Config:
"""Config for BaseModel."""
__type__ = "OceanBase"
arbitrary_types_allowed = True
ob_host: Optional[str] = field(
default=None,
metadata={
"help": "The host of oceanbase, if not set, will use the default host."
},
)
ob_port: Optional[int] = field(
default=None,
metadata={
"help": "The port of oceanbase, if not set, will use the default port."
},
)
ob_user: Optional[str] = field(
default=None,
metadata={
"help": "The user of oceanbase, if not set, will use the default user."
},
)
ob_password: Optional[str] = field(
default=None,
metadata={
"help": "The password of oceanbase, if not set, "
"will use the default password"
},
)
ob_database: Optional[str] = field(
default=None,
metadata={
"help": "The database for vector tables, if not set, "
"will use the default database."
},
)
"""OceanBase config"""
ob_host: Optional[str] = Field(
default=None,
description="oceanbase host",
)
ob_port: Optional[int] = Field(
default=None,
description="oceanbase port",
)
ob_user: Optional[str] = Field(
default=None,
description="user to login",
)
ob_password: Optional[str] = Field(
default=None,
description="password to login",
)
ob_database: Optional[str] = Field(
default=None,
description="database for vector tables",
)
def create_store(self, **kwargs) -> "OceanBaseStore":
    """Build an ``OceanBaseStore`` bound to this config.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``OceanBaseStore`` constructor.

    Returns:
        OceanBaseStore: Store constructed with ``vector_store_config=self``.
    """
    store = OceanBaseStore(vector_store_config=self, **kwargs)
    return store
@register_resource(
@@ -172,7 +185,12 @@ class OceanBaseConfig(VectorStoreConfig):
class OceanBaseStore(VectorStoreBase):
"""OceanBase vector store."""
def __init__(self, vector_store_config: OceanBaseConfig) -> None:
def __init__(
self,
vector_store_config: OceanBaseConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
) -> None:
"""Create a OceanBaseStore instance."""
try:
from pyobvector import ObVecClient # type: ignore
@@ -188,8 +206,8 @@ class OceanBaseStore(VectorStoreBase):
super().__init__()
self._vector_store_config = vector_store_config
self.embedding_function = vector_store_config.embedding_fn
self.table_name = vector_store_config.name
self.embedding_function = embedding_fn
self.table_name = name
vector_store_config_map = vector_store_config.to_dict()
OB_HOST = str(

View File

@@ -1,10 +1,10 @@
"""Postgres vector store."""
import logging
from dataclasses import dataclass, field
from typing import List, Optional
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
_COMMON_PARAMETERS,
@@ -37,17 +37,22 @@ logger = logging.getLogger(__name__)
],
description="PG vector config.",
)
@dataclass
class PGVectorConfig(VectorStoreConfig):
"""PG vector store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
connection_string: str = Field(
connection_string: str = field(
default=None,
description="the connection string of vector store, if not set, will use the "
"default connection string.",
metadata={
"description": "the connection string of vector store, "
"if not set, will use the default connection string."
},
)
def create_store(self, **kwargs) -> "PGVectorStore":
    """Create PGVector store.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``PGVectorStore`` constructor.

    Returns:
        PGVectorStore: Store constructed with ``vector_store_config=self``.
    """
    return PGVectorStore(vector_store_config=self, **kwargs)
@register_resource(
_("PG Vector Store"),
@@ -71,7 +76,12 @@ class PGVectorStore(VectorStoreBase):
To use this, you should have the ``pgvector`` python package installed.
"""
def __init__(self, vector_store_config: PGVectorConfig) -> None:
def __init__(
self,
vector_store_config: PGVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
) -> None:
"""Create a PGVectorStore instance."""
try:
from langchain.vectorstores import PGVector # mypy: ignore
@@ -83,8 +93,8 @@ class PGVectorStore(VectorStoreBase):
self._vector_store_config = vector_store_config
self.connection_string = vector_store_config.connection_string
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name
self.embeddings = embedding_fn
self.collection_name = name
self.vector_store_client = PGVector(
embedding_function=self.embeddings, # type: ignore

View File

@@ -2,10 +2,10 @@
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
_COMMON_PARAMETERS,
@@ -45,20 +45,28 @@ logger = logging.getLogger(__name__)
),
],
)
@dataclass
class WeaviateVectorConfig(VectorStoreConfig):
"""Weaviate vector store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
weaviate_url: str = Field(
weaviate_url: str = field(
default=os.getenv("WEAVIATE_URL", None),
description="weaviate url address, if not set, will use the default url.",
metadata={
"description": "weaviate url address, if not set, "
"will use the default url.",
},
)
persist_path: str = Field(
persist_path: str = field(
default=os.getenv("WEAVIATE_PERSIST_PATH", None),
description="weaviate persist path.",
metadata={
"description": "weaviate persist path.",
},
)
def create_store(self, **kwargs) -> "WeaviateStore":
    """Build a ``WeaviateStore`` bound to this config.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``WeaviateStore`` constructor.

    Returns:
        WeaviateStore: Store constructed with ``vector_store_config=self``.
    """
    store = WeaviateStore(vector_store_config=self, **kwargs)
    return store
@register_resource(
_("Weaviate Vector Store"),
@@ -79,7 +87,12 @@ class WeaviateVectorConfig(VectorStoreConfig):
class WeaviateStore(VectorStoreBase):
"""Weaviate database."""
def __init__(self, vector_store_config: WeaviateVectorConfig) -> None:
def __init__(
self,
vector_store_config: WeaviateVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
) -> None:
"""Initialize with Weaviate client."""
try:
import weaviate
@@ -92,10 +105,10 @@ class WeaviateStore(VectorStoreBase):
self._vector_store_config = vector_store_config
self.weaviate_url = vector_store_config.weaviate_url
self.embedding = vector_store_config.embedding_fn
self.vector_name = vector_store_config.name
self.embedding = embedding_fn
self.vector_name = name
self.persist_dir = os.path.join(
vector_store_config.persist_path, vector_store_config.name + ".vectordb"
vector_store_config.persist_path, name + ".vectordb"
)
self.vector_store_client = weaviate.Client(self.weaviate_url)