mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-10 13:29:35 +00:00
refactor: rag storage refactor (#2434)
This commit is contained in:
@@ -10,7 +10,7 @@ from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
from dbgpt_ext.rag.assembler.base import BaseAssembler
|
||||
from dbgpt_ext.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt_ext.rag.retriever.bm25 import BM25Retriever
|
||||
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConfig
|
||||
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig
|
||||
|
||||
|
||||
class BM25Assembler(BaseAssembler):
|
||||
@@ -44,7 +44,8 @@ class BM25Assembler(BaseAssembler):
|
||||
def __init__(
|
||||
self,
|
||||
knowledge: Knowledge,
|
||||
es_config: ElasticsearchVectorConfig,
|
||||
es_config: ElasticsearchStoreConfig,
|
||||
name: Optional[str] = "dbgpt",
|
||||
k1: Optional[float] = 2.0,
|
||||
b: Optional[float] = 0.75,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
@@ -55,7 +56,7 @@ class BM25Assembler(BaseAssembler):
|
||||
|
||||
Args:
|
||||
knowledge: (Knowledge) Knowledge datasource.
|
||||
es_config: (ElasticsearchVectorConfig) Elasticsearch config.
|
||||
es_config: (ElasticsearchStoreConfig) Elasticsearch config.
|
||||
k1 (Optional[float]): Controls non-linear term frequency normalization
|
||||
(saturation). The default value is 2.0.
|
||||
b (Optional[float]): Controls to what degree document length normalizes
|
||||
@@ -70,7 +71,7 @@ class BM25Assembler(BaseAssembler):
|
||||
self._es_port = es_config.port
|
||||
self._es_username = es_config.user
|
||||
self._es_password = es_config.password
|
||||
self._index_name = es_config.name
|
||||
self._index_name = name
|
||||
self._k1 = k1
|
||||
self._b = b
|
||||
if self._es_username and self._es_password:
|
||||
@@ -123,7 +124,8 @@ class BM25Assembler(BaseAssembler):
|
||||
def load_from_knowledge(
|
||||
cls,
|
||||
knowledge: Knowledge,
|
||||
es_config: ElasticsearchVectorConfig,
|
||||
es_config: ElasticsearchStoreConfig,
|
||||
name: Optional[str] = "dbgpt",
|
||||
k1: Optional[float] = 2.0,
|
||||
b: Optional[float] = 0.75,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
@@ -132,7 +134,8 @@ class BM25Assembler(BaseAssembler):
|
||||
|
||||
Args:
|
||||
knowledge: (Knowledge) Knowledge datasource.
|
||||
es_config: (ElasticsearchVectorConfig) Elasticsearch config.
|
||||
es_config: (ElasticsearchStoreConfig) Elasticsearch config.
|
||||
name: (Optional[str]) BM25 name.
|
||||
k1: (Optional[float]) BM25 parameter k1.
|
||||
b: (Optional[float]) BM25 parameter b.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
@@ -144,6 +147,7 @@ class BM25Assembler(BaseAssembler):
|
||||
return cls(
|
||||
knowledge=knowledge,
|
||||
es_config=es_config,
|
||||
name=name,
|
||||
k1=k1,
|
||||
b=b,
|
||||
chunk_parameters=chunk_parameters,
|
||||
@@ -153,7 +157,8 @@ class BM25Assembler(BaseAssembler):
|
||||
async def aload_from_knowledge(
|
||||
cls,
|
||||
knowledge: Knowledge,
|
||||
es_config: ElasticsearchVectorConfig,
|
||||
es_config: ElasticsearchStoreConfig,
|
||||
name: Optional[str] = "dbgpt",
|
||||
k1: Optional[float] = 2.0,
|
||||
b: Optional[float] = 0.75,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
@@ -163,7 +168,7 @@ class BM25Assembler(BaseAssembler):
|
||||
|
||||
Args:
|
||||
knowledge: (Knowledge) Knowledge datasource.
|
||||
es_config: (ElasticsearchVectorConfig) Elasticsearch config.
|
||||
es_config: (ElasticsearchStoreConfig) Elasticsearch config.
|
||||
k1: (Optional[float]) BM25 parameter k1.
|
||||
b: (Optional[float]) BM25 parameter b.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
@@ -178,6 +183,7 @@ class BM25Assembler(BaseAssembler):
|
||||
cls,
|
||||
knowledge,
|
||||
es_config=es_config,
|
||||
name=name,
|
||||
k1=k1,
|
||||
b=b,
|
||||
chunk_parameters=chunk_parameters,
|
||||
|
@@ -21,7 +21,7 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt_serve.rag.assembler.db_struct import DBSchemaAssembler
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreBase
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
|
||||
connection = SQLiteTempConnector.create_temporary_db()
|
||||
|
@@ -2,8 +2,9 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt.core import Embeddings, LLMClient
|
||||
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
|
||||
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
|
||||
@@ -33,79 +34,86 @@ class GraphRetriever(GraphRetrieverBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
graph_store_adapter,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
llm_model: Optional[str] = None,
|
||||
triplet_graph_enabled: Optional[bool] = True,
|
||||
document_graph_enabled: Optional[bool] = True,
|
||||
extract_top_k: Optional[int] = 5,
|
||||
kg_chunk_search_top_k: Optional[int] = 5,
|
||||
similarity_top_k: Optional[int] = 5,
|
||||
similarity_score_threshold: Optional[float] = 0.7,
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
embedding_batch_size: Optional[int] = 20,
|
||||
enable_text_search: Optional[bool] = False,
|
||||
text2gql_model_enabled: Optional[bool] = False,
|
||||
text2gql_model_name: Optional[str] = None,
|
||||
):
|
||||
"""Initialize Graph Retriever."""
|
||||
self._triplet_graph_enabled = config.triplet_graph_enabled or (
|
||||
self._triplet_graph_enabled = triplet_graph_enabled or (
|
||||
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
self._document_graph_enabled = config.document_graph_enabled or (
|
||||
self._document_graph_enabled = document_graph_enabled or (
|
||||
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
triplet_topk = int(
|
||||
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||
extract_top_k or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||
)
|
||||
document_topk = int(
|
||||
config.knowledge_graph_chunk_search_top_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE")
|
||||
kg_chunk_search_top_k or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE")
|
||||
)
|
||||
llm_client = config.llm_client
|
||||
model_name = config.model_name
|
||||
llm_client = llm_client
|
||||
model_name = llm_model
|
||||
self._enable_similarity_search = (
|
||||
graph_store_adapter.graph_store.enable_similarity_search
|
||||
)
|
||||
self._embedding_batch_size = int(
|
||||
config.knowledge_graph_embedding_batch_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||
embedding_batch_size or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||
)
|
||||
similarity_search_topk = int(
|
||||
config.similarity_search_topk
|
||||
or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE")
|
||||
similarity_top_k or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE")
|
||||
)
|
||||
similarity_search_score_threshold = float(
|
||||
config.extract_score_threshold
|
||||
similarity_score_threshold
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
|
||||
)
|
||||
self._enable_text_search = config.enable_text_search or (
|
||||
self._enable_text_search = enable_text_search or (
|
||||
os.getenv("TEXT_SEARCH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
text2gql_model_enabled = config.text2gql_model_enabled or (
|
||||
text2gql_model_enabled = text2gql_model_enabled or (
|
||||
os.getenv("TEXT2GQL_MODEL_ENABLED", "").lower() == "true"
|
||||
)
|
||||
text2gql_model_name = config.text2gql_model_name or os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME"
|
||||
text2gql_model_name = text2gql_model_name or os.getenv("TEXT2GQL_MODEL_NAME")
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||
else text2gql_model_enabled
|
||||
)
|
||||
text2gql_model_name = os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME",
|
||||
text2gql_model_name,
|
||||
)
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||
else config.text2gql_model_enabled
|
||||
else text2gql_model_enabled
|
||||
)
|
||||
text2gql_model_name = os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME",
|
||||
config.text2gql_model_name,
|
||||
text2gql_model_name,
|
||||
)
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||
else config.text2gql_model_enabled
|
||||
else text2gql_model_enabled
|
||||
)
|
||||
text2gql_model_name = os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME",
|
||||
config.text2gql_model_name,
|
||||
)
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||
else config.text2gql_model_enabled
|
||||
)
|
||||
text2gql_model_name = os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME",
|
||||
config.text2gql_model_name,
|
||||
text2gql_model_name,
|
||||
)
|
||||
|
||||
self._keyword_extractor = KeywordExtractor(llm_client, model_name)
|
||||
self._text_embedder = TextEmbedder(config.embedding_fn)
|
||||
self._text_embedder = TextEmbedder(embedding_fn)
|
||||
|
||||
intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
|
||||
if text2gql_model_enabled:
|
||||
|
@@ -17,19 +17,27 @@ class GraphExtractor(LLMExtractor):
|
||||
"""GraphExtractor class."""
|
||||
|
||||
def __init__(
|
||||
self, llm_client: LLMClient, model_name: str, chunk_history: VectorStoreBase
|
||||
self,
|
||||
llm_client: LLMClient,
|
||||
model_name: str,
|
||||
chunk_history: VectorStoreBase,
|
||||
index_name: str,
|
||||
max_chunks_once_load: Optional[int] = 10,
|
||||
max_threads: Optional[int] = 1,
|
||||
top_k: Optional[int] = 5,
|
||||
score_threshold: Optional[float] = 0.7,
|
||||
):
|
||||
"""Initialize the GraphExtractor."""
|
||||
super().__init__(llm_client, model_name, GRAPH_EXTRACT_PT_CN)
|
||||
self._chunk_history = chunk_history
|
||||
|
||||
config = self._chunk_history.get_config()
|
||||
# config = self._chunk_history.get_config()
|
||||
|
||||
self._vector_space = config.name
|
||||
self._max_chunks_once_load = config.max_chunks_once_load
|
||||
self._max_threads = config.max_threads
|
||||
self._topk = config.topk
|
||||
self._score_threshold = config.score_threshold
|
||||
self._vector_space = index_name
|
||||
self._max_chunks_once_load = max_chunks_once_load
|
||||
self._max_threads = max_threads
|
||||
self._topk = top_k
|
||||
self._score_threshold = score_threshold
|
||||
|
||||
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
|
||||
"""Load chunk context."""
|
||||
|
@@ -50,11 +50,11 @@ def _import_oceanbase() -> Tuple[Type, Type]:
|
||||
|
||||
def _import_elastic() -> Tuple[Type, Type]:
|
||||
from dbgpt_ext.storage.vector_store.elastic_store import (
|
||||
ElasticsearchVectorConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
ElasticStore,
|
||||
)
|
||||
|
||||
return ElasticStore, ElasticsearchVectorConfig
|
||||
return ElasticStore, ElasticsearchStoreConfig
|
||||
|
||||
|
||||
def _import_builtin_knowledge_graph() -> Tuple[Type, Type]:
|
||||
@@ -115,6 +115,35 @@ def _select_rag_storage(name: str) -> Tuple[Type, Type]:
|
||||
raise AttributeError(f"Could not find: {name}")
|
||||
|
||||
|
||||
_HAS_SCAN = False
|
||||
|
||||
|
||||
def scan_storage_configs():
|
||||
"""Scan storage configs."""
|
||||
from dbgpt.storage.base import IndexStoreConfig
|
||||
from dbgpt.util.module_utils import ModelScanner, ScannerConfig
|
||||
|
||||
global _HAS_SCAN
|
||||
|
||||
if _HAS_SCAN:
|
||||
return
|
||||
modules = [
|
||||
"dbgpt_ext.storage.vector_store",
|
||||
"dbgpt_ext.storage.knowledge_graph",
|
||||
"dbgpt_ext.storage.graph_store",
|
||||
]
|
||||
|
||||
scanner = ModelScanner[IndexStoreConfig]()
|
||||
for module in modules:
|
||||
config = ScannerConfig(
|
||||
module_path=module,
|
||||
base_class=IndexStoreConfig,
|
||||
)
|
||||
scanner.scan_and_register(config)
|
||||
_HAS_SCAN = True
|
||||
return scanner.get_registered_items()
|
||||
|
||||
|
||||
__vector_store__ = [
|
||||
"Chroma",
|
||||
"Milvus",
|
||||
|
@@ -11,21 +11,19 @@ from dbgpt.storage.full_text.base import FullTextStoreBase
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util import string_utils
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConfig
|
||||
|
||||
|
||||
class ElasticDocumentConfig(ElasticsearchVectorConfig):
|
||||
"""Elasticsearch document store config."""
|
||||
|
||||
k1: Optional[float] = 2.0
|
||||
b: Optional[float] = 0.75
|
||||
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig
|
||||
|
||||
|
||||
class ElasticDocumentStore(FullTextStoreBase):
|
||||
"""Elasticsearch index store."""
|
||||
|
||||
def __init__(
|
||||
self, es_config: ElasticDocumentConfig, executor: Optional[Executor] = None
|
||||
self,
|
||||
es_config: ElasticsearchStoreConfig,
|
||||
name: Optional[str] = "dbgpt",
|
||||
k1: Optional[float] = 2.0,
|
||||
b: Optional[float] = 0.75,
|
||||
executor: Optional[Executor] = None,
|
||||
):
|
||||
"""Init elasticsearch index store.
|
||||
|
||||
@@ -46,17 +44,17 @@ class ElasticDocumentStore(FullTextStoreBase):
|
||||
self._es_password = es_config.password or os.getenv(
|
||||
"ELASTICSEARCH_PASSWORD", "dbgpt"
|
||||
)
|
||||
self._index_name = es_config.name.lower()
|
||||
if string_utils.contains_chinese(es_config.name):
|
||||
bytes_str = es_config.name.encode("utf-8")
|
||||
self._index_name = name.lower()
|
||||
if string_utils.contains_chinese(name):
|
||||
bytes_str = name.encode("utf-8")
|
||||
hex_str = bytes_str.hex()
|
||||
self._index_name = "dbgpt_" + hex_str
|
||||
# k1 (Optional[float]): Controls non-linear term frequency normalization
|
||||
# (saturation). The default value is 2.0.
|
||||
self._k1 = es_config.k1 or 2.0
|
||||
self._k1 = k1 or 2.0
|
||||
# b (Optional[float]): Controls to what degree document length normalizes
|
||||
# tf values. The default value is 0.75.
|
||||
self._b = es_config.b or 0.75
|
||||
self._b = b or 0.75
|
||||
if self._es_username and self._es_password:
|
||||
self._es_client = Elasticsearch(
|
||||
hosts=[f"http://{self._es_url}:{self._es_port}"],
|
||||
|
@@ -4,9 +4,9 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import GraphElemType
|
||||
from dbgpt_ext.datasource.conn_tugraph import TuGraphConnector
|
||||
@@ -14,66 +14,89 @@ from dbgpt_ext.datasource.conn_tugraph import TuGraphConnector
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TuGraphStoreConfig(GraphStoreConfig):
|
||||
"""TuGraph store config."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
__type__ = "TuGraph"
|
||||
|
||||
host: str = Field(
|
||||
host: str = field(
|
||||
default="127.0.0.1",
|
||||
description="TuGraph host",
|
||||
metadata={
|
||||
"description": "TuGraph host",
|
||||
},
|
||||
)
|
||||
port: int = Field(
|
||||
port: int = field(
|
||||
default=7687,
|
||||
description="TuGraph port",
|
||||
metadata={
|
||||
"description": "TuGraph port",
|
||||
},
|
||||
)
|
||||
username: str = Field(
|
||||
username: str = field(
|
||||
default="admin",
|
||||
description="login username",
|
||||
metadata={
|
||||
"description": "login username",
|
||||
},
|
||||
)
|
||||
password: str = Field(
|
||||
password: str = field(
|
||||
default="73@TuGraph",
|
||||
description="login password",
|
||||
metadata={
|
||||
"description": "login password",
|
||||
},
|
||||
)
|
||||
vertex_type: str = Field(
|
||||
vertex_type: str = field(
|
||||
default=GraphElemType.ENTITY.value,
|
||||
description="The type of entity vertex, `entity` by default.",
|
||||
metadata={
|
||||
"description": "The type of vertex, `entity` by default.",
|
||||
},
|
||||
)
|
||||
document_type: str = Field(
|
||||
document_type: str = field(
|
||||
default=GraphElemType.DOCUMENT.value,
|
||||
description="The type of document vertex, `document` by default.",
|
||||
metadata={
|
||||
"description": "The type of document vertex, `document` by default.",
|
||||
},
|
||||
)
|
||||
chunk_type: str = Field(
|
||||
chunk_type: str = field(
|
||||
default=GraphElemType.CHUNK.value,
|
||||
description="The type of chunk vertex, `relation` by default.",
|
||||
metadata={
|
||||
"description": "The type of chunk vertex, `relation` by default.",
|
||||
},
|
||||
)
|
||||
edge_type: str = Field(
|
||||
edge_type: str = field(
|
||||
default=GraphElemType.RELATION.value,
|
||||
description="The type of relation edge, `relation` by default.",
|
||||
metadata={
|
||||
"description": "The type of relation edge, `relation` by default.",
|
||||
},
|
||||
)
|
||||
include_type: str = Field(
|
||||
include_type: str = field(
|
||||
default=GraphElemType.INCLUDE.value,
|
||||
description="The type of include edge, `include` by default.",
|
||||
metadata={
|
||||
"description": "The type of include edge, `include` by default.",
|
||||
},
|
||||
)
|
||||
next_type: str = Field(
|
||||
next_type: str = field(
|
||||
default=GraphElemType.NEXT.value,
|
||||
description="The type of next edge, `next` by default.",
|
||||
metadata={
|
||||
"description": "The type of next edge, `next` by default.",
|
||||
},
|
||||
)
|
||||
plugin_names: List[str] = Field(
|
||||
default=["leiden"],
|
||||
description=(
|
||||
"Plugins need to be loaded when initialize TuGraph, "
|
||||
"code: https://github.com/TuGraph-family"
|
||||
"/dbgpt-tugraph-plugins/tree/master/cpp"
|
||||
),
|
||||
plugin_names: List[str] = field(
|
||||
default_factory=lambda: ["leiden"],
|
||||
metadata={
|
||||
"description": "The list of plugin names to be uploaded to the database.",
|
||||
},
|
||||
)
|
||||
enable_summary: bool = Field(
|
||||
enable_summary: bool = field(
|
||||
default=True,
|
||||
description="Enable graph community summary or not.",
|
||||
metadata={
|
||||
"description": "Enable graph community summary or not.",
|
||||
},
|
||||
)
|
||||
enable_similarity_search: bool = Field(
|
||||
enable_similarity_search: bool = field(
|
||||
default=False,
|
||||
description="Enable the similarity search or not",
|
||||
metadata={
|
||||
"description": "Enable the similarity search or not",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
@@ -18,18 +18,24 @@ class BuiltinCommunityMetastore(CommunityMetastore):
|
||||
"""Builtin Community metastore."""
|
||||
|
||||
def __init__(
|
||||
self, vector_store: VectorStoreBase, rdb_store: Optional[RDBMSConnector] = None
|
||||
self,
|
||||
vector_store: VectorStoreBase,
|
||||
rdb_store: Optional[RDBMSConnector] = None,
|
||||
index_name: Optional[str] = None,
|
||||
max_chunks_once_load: Optional[int] = 10,
|
||||
max_threads: Optional[int] = 1,
|
||||
top_k: Optional[int] = 5,
|
||||
score_threshold: Optional[float] = 0.7,
|
||||
):
|
||||
"""Initialize Community metastore."""
|
||||
self._vector_store = vector_store
|
||||
self._rdb_store = rdb_store
|
||||
|
||||
config = self._vector_store.get_config()
|
||||
self._vector_space = config.name
|
||||
self._max_chunks_once_load = config.max_chunks_once_load
|
||||
self._max_threads = config.max_threads
|
||||
self._topk = config.topk
|
||||
self._score_threshold = config.score_threshold
|
||||
self._vector_space = index_name
|
||||
self._max_chunks_once_load = max_chunks_once_load
|
||||
self._max_threads = max_threads
|
||||
self._topk = top_k
|
||||
self._score_threshold = score_threshold
|
||||
|
||||
def get(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
|
@@ -25,11 +25,23 @@ class CommunityStore:
|
||||
graph_store_adapter: GraphStoreAdapter,
|
||||
community_summarizer: CommunitySummarizer,
|
||||
vector_store: VectorStoreBase,
|
||||
index_name: Optional[str] = None,
|
||||
max_chunks_once_load: Optional[int] = 10,
|
||||
max_threads: Optional[int] = 1,
|
||||
top_k: Optional[int] = 5,
|
||||
score_threshold: Optional[float] = 0.7,
|
||||
):
|
||||
"""Initialize the CommunityStore class."""
|
||||
self._graph_store_adapter = graph_store_adapter
|
||||
self._community_summarizer = community_summarizer
|
||||
self._meta_store = BuiltinCommunityMetastore(vector_store)
|
||||
self._meta_store = BuiltinCommunityMetastore(
|
||||
vector_store=vector_store,
|
||||
index_name=index_name,
|
||||
max_chunks_once_load=max_chunks_once_load,
|
||||
max_threads=max_threads,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
async def build_communities(self, batch_size: int = 1):
|
||||
"""Discover communities."""
|
||||
|
@@ -5,9 +5,9 @@ import os
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk, LLMClient
|
||||
from dbgpt.core import Chunk, Embeddings, LLMClient
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.graph_store.base import GraphStoreConfig
|
||||
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
@@ -17,13 +17,13 @@ from dbgpt_ext.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt_ext.rag.transformer.graph_embedder import GraphEmbedder
|
||||
from dbgpt_ext.rag.transformer.graph_extractor import GraphExtractor
|
||||
from dbgpt_ext.rag.transformer.text_embedder import TextEmbedder
|
||||
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
|
||||
from dbgpt_ext.storage.knowledge_graph.community.community_store import CommunityStore
|
||||
from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
|
||||
GRAPH_PARAMETERS,
|
||||
BuiltinKnowledgeGraph,
|
||||
BuiltinKnowledgeGraphConfig,
|
||||
)
|
||||
from dbgpt_ext.storage.vector_store.factory import VectorStoreFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -139,87 +139,6 @@ logger = logging.getLogger(__name__)
|
||||
),
|
||||
],
|
||||
)
|
||||
class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
"""Community summary knowledge graph config."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
vector_store_type: str = Field(
|
||||
default="Chroma",
|
||||
description="The type of vector store.",
|
||||
)
|
||||
user: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The user of vector store, if not set, will use the default user.",
|
||||
)
|
||||
password: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The password of vector store, if not set, will use the default password."
|
||||
),
|
||||
)
|
||||
extract_topk: int = Field(
|
||||
default=5,
|
||||
description="Topk of knowledge graph extract",
|
||||
)
|
||||
extract_score_threshold: float = Field(
|
||||
default=0.3,
|
||||
description="Recall score of knowledge graph extract",
|
||||
)
|
||||
community_topk: int = Field(
|
||||
default=50,
|
||||
description="Topk of community search in knowledge graph",
|
||||
)
|
||||
community_score_threshold: float = Field(
|
||||
default=0.3,
|
||||
description="Recall score of community search in knowledge graph",
|
||||
)
|
||||
triplet_graph_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable the graph search for triplets",
|
||||
)
|
||||
document_graph_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable the graph search for documents and chunks",
|
||||
)
|
||||
knowledge_graph_chunk_search_top_size: int = Field(
|
||||
default=5,
|
||||
description="Top size of knowledge graph chunk search",
|
||||
)
|
||||
knowledge_graph_extraction_batch_size: int = Field(
|
||||
default=20,
|
||||
description="Batch size of triplets extraction from the text",
|
||||
)
|
||||
community_summary_batch_size: int = Field(
|
||||
default=20,
|
||||
description="Batch size of parallel community building process",
|
||||
)
|
||||
knowledge_graph_embedding_batch_size: int = Field(
|
||||
default=20,
|
||||
description="Batch size of triplets embedding from the text",
|
||||
)
|
||||
similarity_search_topk: int = Field(
|
||||
default=5,
|
||||
description="Topk of similarity search",
|
||||
)
|
||||
similarity_search_score_threshold: float = Field(
|
||||
default=0.7,
|
||||
description="Recall score of similarity search",
|
||||
)
|
||||
enable_text_search: bool = Field(
|
||||
default=False,
|
||||
description="Enable text2gql search or not.",
|
||||
)
|
||||
text2gql_model_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable fine-tuned text2gql model for text2gql translation.",
|
||||
)
|
||||
text2gql_model_name: str = Field(
|
||||
default=None,
|
||||
description="LLM Model for text2gql translation.",
|
||||
)
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Community Summary Knowledge Graph"),
|
||||
"community_summary_knowledge_graph",
|
||||
@@ -239,104 +158,137 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
"""Community summary knowledge graph class."""
|
||||
|
||||
def __init__(self, config: CommunitySummaryKnowledgeGraphConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GraphStoreConfig,
|
||||
name: Optional[str] = "dbgpt",
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
llm_model: Optional[str] = None,
|
||||
kg_extract_top_k: Optional[int] = 5,
|
||||
kg_extract_score_threshold: Optional[float] = 0.3,
|
||||
kg_community_top_k: Optional[int] = 50,
|
||||
kg_community_score_threshold: Optional[float] = 0.3,
|
||||
kg_triplet_graph_enabled: Optional[bool] = True,
|
||||
kg_document_graph_enabled: Optional[bool] = True,
|
||||
kg_chunk_search_top_k: Optional[int] = 5,
|
||||
kg_extraction_batch_size: Optional[int] = 3,
|
||||
kg_community_summary_batch_size: Optional[int] = 20,
|
||||
kg_embedding_batch_size: Optional[int] = 20,
|
||||
kg_similarity_top_k: Optional[int] = 5,
|
||||
kg_similarity_score_threshold: Optional[float] = 0.7,
|
||||
kg_enable_text_search: Optional[float] = False,
|
||||
kg_text2gql_model_enabled: Optional[bool] = False,
|
||||
kg_text2gql_model_name: Optional[str] = None,
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
vector_store_config: Optional["VectorStoreConfig"] = None,
|
||||
kg_max_chunks_once_load: Optional[int] = 10,
|
||||
kg_max_threads: Optional[int] = 1,
|
||||
):
|
||||
"""Initialize community summary knowledge graph class."""
|
||||
super().__init__(config)
|
||||
super().__init__(
|
||||
config=config, name=name, llm_client=llm_client, llm_model=llm_model
|
||||
)
|
||||
self._config = config
|
||||
|
||||
self._vector_store_type = config.vector_store_type or os.getenv(
|
||||
self._vector_store_type = config.get_type_value() or os.getenv(
|
||||
"VECTOR_STORE_TYPE"
|
||||
)
|
||||
self._extract_topk = int(
|
||||
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||
kg_extract_top_k or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||
)
|
||||
self._extract_score_threshold = float(
|
||||
config.extract_score_threshold
|
||||
kg_extract_score_threshold
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
|
||||
)
|
||||
self._community_topk = int(
|
||||
config.community_topk
|
||||
or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE")
|
||||
kg_community_top_k or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE")
|
||||
)
|
||||
self._community_score_threshold = float(
|
||||
config.community_score_threshold
|
||||
kg_community_score_threshold
|
||||
or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE")
|
||||
)
|
||||
self._document_graph_enabled = config.document_graph_enabled or (
|
||||
self._document_graph_enabled = kg_document_graph_enabled or (
|
||||
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
self._triplet_graph_enabled = config.triplet_graph_enabled or (
|
||||
self._triplet_graph_enabled = kg_triplet_graph_enabled or (
|
||||
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
self._triplet_extraction_batch_size = int(
|
||||
config.knowledge_graph_extraction_batch_size
|
||||
kg_extraction_batch_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE")
|
||||
)
|
||||
self._triplet_embedding_batch_size = int(
|
||||
config.knowledge_graph_embedding_batch_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||
kg_embedding_batch_size or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||
)
|
||||
self._community_summary_batch_size = int(
|
||||
config.community_summary_batch_size
|
||||
or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE")
|
||||
kg_community_summary_batch_size or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE")
|
||||
)
|
||||
|
||||
def extractor_configure(name: str, cfg: VectorStoreConfig):
|
||||
cfg.name = name
|
||||
cfg.embedding_fn = config.embedding_fn
|
||||
cfg.max_chunks_once_load = config.max_chunks_once_load
|
||||
cfg.max_threads = config.max_threads
|
||||
cfg.user = config.user
|
||||
cfg.password = config.password
|
||||
cfg.topk = self._extract_topk
|
||||
cfg.score_threshold = self._extract_score_threshold
|
||||
self._embedding_fn = embedding_fn
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
self._graph_extractor = GraphExtractor(
|
||||
self._llm_client,
|
||||
self._model_name,
|
||||
VectorStoreFactory.create(
|
||||
self._vector_store_type,
|
||||
config.name + "_CHUNK_HISTORY",
|
||||
extractor_configure,
|
||||
vector_store_config.create_store(
|
||||
name=name + "_CHUNK_HISTORY", embedding_fn=embedding_fn
|
||||
),
|
||||
index_name=name,
|
||||
max_chunks_once_load=kg_max_chunks_once_load,
|
||||
max_threads=kg_max_threads,
|
||||
top_k=kg_extract_top_k,
|
||||
score_threshold=kg_extract_score_threshold,
|
||||
)
|
||||
|
||||
self._graph_embedder = GraphEmbedder(self._config.embedding_fn)
|
||||
self._text_embedder = TextEmbedder(self._config.embedding_fn)
|
||||
self._graph_embedder = GraphEmbedder(embedding_fn)
|
||||
self._text_embedder = TextEmbedder(embedding_fn)
|
||||
|
||||
def community_store_configure(name: str, cfg: VectorStoreConfig):
|
||||
cfg.name = name
|
||||
cfg.embedding_fn = config.embedding_fn
|
||||
cfg.max_chunks_once_load = config.max_chunks_once_load
|
||||
cfg.max_threads = config.max_threads
|
||||
cfg.user = config.user
|
||||
cfg.password = config.password
|
||||
cfg.topk = self._community_topk
|
||||
cfg.score_threshold = self._community_score_threshold
|
||||
# def community_store_configure(name: str, cfg: VectorStoreConfig):
|
||||
# cfg.name = name
|
||||
# cfg.embedding_fn = self._embedding_fn
|
||||
# cfg.max_chunks_once_load = max_chunks_once_load
|
||||
# cfg.max_threads = max_threads
|
||||
# cfg.user = config.user
|
||||
# cfg.password = config.password
|
||||
# cfg.topk = self._community_topk
|
||||
# cfg.score_threshold = self._community_score_threshold
|
||||
|
||||
self._community_store = CommunityStore(
|
||||
self._graph_store_adapter,
|
||||
CommunitySummarizer(self._llm_client, self._model_name),
|
||||
VectorStoreFactory.create(
|
||||
self._vector_store_type,
|
||||
config.name + "_COMMUNITY_SUMMARY",
|
||||
community_store_configure,
|
||||
vector_store_config.create_store(
|
||||
name=name + "_COMMUNITY_SUMMARY", embedding_fn=embedding_fn
|
||||
),
|
||||
index_name=name,
|
||||
max_chunks_once_load=kg_max_chunks_once_load,
|
||||
max_threads=kg_max_threads,
|
||||
top_k=kg_community_top_k,
|
||||
score_threshold=kg_extract_score_threshold,
|
||||
)
|
||||
|
||||
self._graph_retriever = GraphRetriever(
|
||||
config,
|
||||
self._graph_store_adapter,
|
||||
llm_client=llm_client,
|
||||
llm_model=llm_model,
|
||||
triplet_graph_enabled=kg_triplet_graph_enabled,
|
||||
document_graph_enabled=kg_document_graph_enabled,
|
||||
extract_top_k=kg_extract_top_k,
|
||||
kg_chunk_search_top_k=kg_chunk_search_top_k,
|
||||
similarity_top_k=kg_similarity_top_k,
|
||||
similarity_score_threshold=kg_similarity_score_threshold,
|
||||
embedding_fn=embedding_fn,
|
||||
embedding_batch_size=kg_embedding_batch_size,
|
||||
text2gql_model_enabled=kg_text2gql_model_enabled,
|
||||
text2gql_model_name=kg_text2gql_model_name,
|
||||
)
|
||||
|
||||
def get_config(self) -> BuiltinKnowledgeGraphConfig:
|
||||
def get_config(self) -> TuGraphStoreConfig:
|
||||
"""Get the knowledge graph config."""
|
||||
return self._config
|
||||
|
||||
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Extract and persist graph from the document file."""
|
||||
if not self.vector_name_exists():
|
||||
self._graph_store_adapter.create_graph(self.get_config().name)
|
||||
self._graph_store_adapter.create_graph(self._graph_name)
|
||||
await self._aload_document_graph(chunks)
|
||||
await self._aload_triplet_graph(chunks)
|
||||
await self._community_store.build_communities(
|
||||
@@ -447,7 +399,8 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
chunk_name=doc_name,
|
||||
)
|
||||
|
||||
# chunk.metadata = {"Header0": "title", "Header1": "title", ..., "source": "source_path"} # noqa: E501
|
||||
# chunk.metadata = {"Header0": "title",
|
||||
# "Header1": "title", ..., "source": "source_path"} # noqa: E501
|
||||
for chunk_index, chunk in enumerate(chunks):
|
||||
parent = None
|
||||
directory_keys = list(chunk.metadata.keys())[
|
||||
@@ -542,7 +495,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
self._graph_embedder.truncate()
|
||||
logger.info("Truncate text embedder")
|
||||
self._text_embedder.truncate()
|
||||
return [self._config.name]
|
||||
return [self._graph_name]
|
||||
|
||||
def delete_vector_name(self, index_name: str):
|
||||
"""Delete knowledge graph."""
|
||||
|
@@ -3,9 +3,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk, Embeddings, LLMClient
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||
@@ -16,6 +16,7 @@ from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
from dbgpt_ext.rag.transformer.triplet_extractor import TripletExtractor
|
||||
from dbgpt_ext.storage.graph_store.factory import GraphStoreFactory
|
||||
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
|
||||
from dbgpt_ext.storage.knowledge_graph.community.base import GraphStoreAdapter
|
||||
from dbgpt_ext.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
|
||||
|
||||
@@ -98,16 +99,19 @@ GRAPH_PARAMETERS = [
|
||||
),
|
||||
],
|
||||
)
|
||||
@dataclass
|
||||
class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig):
|
||||
"""Builtin knowledge graph config."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
__type__ = "TuGraph"
|
||||
|
||||
llm_client: LLMClient = Field(default=None, description="The default llm client.")
|
||||
llm_model: Optional[str] = field(
|
||||
default=None, metadata={"description": "llm model name."}
|
||||
)
|
||||
|
||||
model_name: str = Field(default=None, description="The name of llm model.")
|
||||
|
||||
type: str = Field(default="TuGraph", description="The type of graph store.")
|
||||
graph_type: Optional[str] = field(
|
||||
default="TuGraph", metadata={"description": "graph store type."}
|
||||
)
|
||||
|
||||
|
||||
@register_resource(
|
||||
@@ -129,33 +133,38 @@ class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig):
|
||||
class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
"""Builtin knowledge graph class."""
|
||||
|
||||
def __init__(self, config: BuiltinKnowledgeGraphConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GraphStoreConfig = None,
|
||||
name: Optional[str] = "dbgpt",
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
llm_model: Optional[str] = None,
|
||||
):
|
||||
"""Create builtin knowledge graph instance."""
|
||||
super().__init__()
|
||||
self._config = config
|
||||
|
||||
self._llm_client = config.llm_client
|
||||
self._llm_client = llm_client
|
||||
self._graph_name = name
|
||||
if not self._llm_client:
|
||||
raise ValueError("No llm client provided.")
|
||||
|
||||
self._model_name = config.model_name
|
||||
self._model_name = llm_model
|
||||
self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name)
|
||||
self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name)
|
||||
self._graph_store: GraphStoreBase = self.__init_graph_store(config)
|
||||
self._graph_store_adapter: GraphStoreAdapter = self.__init_graph_store_adapter()
|
||||
|
||||
def __init_graph_store(self, config: BuiltinKnowledgeGraphConfig) -> GraphStoreBase:
|
||||
def __init_graph_store(self, config: GraphStoreConfig) -> GraphStoreBase:
|
||||
def configure(cfg: GraphStoreConfig):
|
||||
cfg.name = config.name
|
||||
cfg.embedding_fn = config.embedding_fn
|
||||
cfg.name = self._graph_name
|
||||
|
||||
graph_store_type = config.type or os.getenv("GRAPH_STORE_TYPE")
|
||||
return GraphStoreFactory.create(graph_store_type, configure, config.dict())
|
||||
graph_store_type = config.get_type_value() or os.getenv("GRAPH_STORE_TYPE")
|
||||
return GraphStoreFactory.create(graph_store_type, configure, config.to_dict())
|
||||
|
||||
def __init_graph_store_adapter(self):
|
||||
return GraphStoreAdapterFactory.create(self._graph_store)
|
||||
|
||||
def get_config(self) -> BuiltinKnowledgeGraphConfig:
|
||||
def get_config(self) -> TuGraphStoreConfig:
|
||||
"""Get the knowledge graph config."""
|
||||
return self._config
|
||||
|
||||
@@ -171,7 +180,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
|
||||
# wait async tasks completed
|
||||
if not self.vector_name_exists():
|
||||
self._graph_store_adapter.create_graph(self.get_config().name)
|
||||
self._graph_store_adapter.create_graph(self._graph_name)
|
||||
tasks = [process_chunk(chunk) for chunk in chunks]
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
@@ -188,7 +197,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
List[str]: chunk ids.
|
||||
"""
|
||||
if not self.vector_name_exists():
|
||||
self._graph_store_adapter.create_graph(self.get_config().name)
|
||||
self._graph_store_adapter.create_graph(self._graph_name)
|
||||
for chunk in chunks:
|
||||
triplets = await self._triplet_extractor.extract(chunk.content)
|
||||
for triplet in triplets:
|
||||
@@ -257,7 +266,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
|
||||
def truncate(self) -> List[str]:
|
||||
"""Truncate knowledge graph."""
|
||||
logger.info(f"Truncate graph {self._config.name}")
|
||||
logger.info(f"Truncate graph {self._graph_name}")
|
||||
self._graph_store_adapter.truncate()
|
||||
|
||||
logger.info("Truncate keyword extractor")
|
||||
@@ -266,7 +275,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
logger.info("Truncate triplet extractor")
|
||||
self._triplet_extractor.truncate()
|
||||
|
||||
return [self._config.name]
|
||||
return [self._graph_name]
|
||||
|
||||
def delete_vector_name(self, index_name: str):
|
||||
"""Delete vector name."""
|
||||
@@ -286,4 +295,4 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""Whether name exists."""
|
||||
return self._graph_store_adapter.graph_store.is_exist(self._config.name)
|
||||
return self._graph_store_adapter.graph_store.is_exist(self._graph_name)
|
||||
|
@@ -2,11 +2,11 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.configs.model_config import PILOT_PATH, resolve_root_path
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
_COMMON_PARAMETERS,
|
||||
@@ -18,8 +18,6 @@ from dbgpt.util.i18n_utils import _
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHROMA_COLLECTION_NAME = "langchain"
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Chroma Config"),
|
||||
@@ -38,21 +36,29 @@ CHROMA_COLLECTION_NAME = "langchain"
|
||||
),
|
||||
],
|
||||
)
|
||||
@dataclass
|
||||
class ChromaVectorConfig(VectorStoreConfig):
|
||||
"""Chroma vector store config."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
__type__ = "Chroma"
|
||||
|
||||
persist_path: Optional[str] = Field(
|
||||
persist_path: Optional[str] = field(
|
||||
default=os.getenv("CHROMA_PERSIST_PATH", None),
|
||||
description="the persist path of vector store.",
|
||||
metadata={
|
||||
"help": _("The persist path of vector store."),
|
||||
},
|
||||
)
|
||||
collection_metadata: Optional[dict] = Field(
|
||||
collection_metadata: Optional[dict] = field(
|
||||
default=None,
|
||||
description="the index metadata of vector store, if not set, will use the "
|
||||
"default metadata.",
|
||||
metadata={
|
||||
"help": _("The metadata of collection."),
|
||||
},
|
||||
)
|
||||
|
||||
def create_store(self, **kwargs) -> "ChromaStore":
|
||||
"""Create index store."""
|
||||
return ChromaStore(vector_store_config=self, **kwargs)
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Chroma Vector Store"),
|
||||
@@ -73,11 +79,22 @@ class ChromaVectorConfig(VectorStoreConfig):
|
||||
class ChromaStore(VectorStoreBase):
|
||||
"""Chroma vector store."""
|
||||
|
||||
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_config: ChromaVectorConfig,
|
||||
name: Optional[str],
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
chroma_client: Optional["PersistentClient"] = None, # type: ignore # noqa
|
||||
collection_metadata: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Create a ChromaStore instance.
|
||||
|
||||
Args:
|
||||
vector_store_config(ChromaVectorConfig): vector store config.
|
||||
name(str): collection name.
|
||||
embedding_fn(Embeddings): embedding function.
|
||||
chroma_client(PersistentClient): chroma client.
|
||||
collection_metadata(dict): collection metadata.
|
||||
"""
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
@@ -85,28 +102,27 @@ class ChromaStore(VectorStoreBase):
|
||||
from chromadb import PersistentClient, Settings
|
||||
except ImportError:
|
||||
raise ImportError("Please install chroma package first.")
|
||||
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
|
||||
chroma_vector_config = vector_store_config.to_dict()
|
||||
chroma_path = chroma_vector_config.get(
|
||||
"persist_path", os.path.join(PILOT_PATH, "data")
|
||||
)
|
||||
self.persist_dir = os.path.join(
|
||||
resolve_root_path(chroma_path), vector_store_config.name + ".vectordb"
|
||||
)
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
self.persist_dir = os.path.join(resolve_root_path(chroma_path) + "/chromadb")
|
||||
self.embeddings = embedding_fn
|
||||
if not self.embeddings:
|
||||
raise ValueError("Embeddings is None")
|
||||
chroma_settings = Settings(
|
||||
# chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
|
||||
persist_directory=self.persist_dir,
|
||||
anonymized_telemetry=False,
|
||||
)
|
||||
self._chroma_client = PersistentClient(
|
||||
path=self.persist_dir, settings=chroma_settings
|
||||
)
|
||||
|
||||
collection_metadata = chroma_vector_config.get("collection_metadata") or {
|
||||
"hnsw:space": "cosine"
|
||||
}
|
||||
self._chroma_client = chroma_client
|
||||
if not self._chroma_client:
|
||||
self._chroma_client = PersistentClient(
|
||||
path=self.persist_dir, settings=chroma_settings
|
||||
)
|
||||
collection_metadata = collection_metadata or {"hnsw:space": "cosine"}
|
||||
self._collection = self._chroma_client.get_or_create_collection(
|
||||
name=CHROMA_COLLECTION_NAME,
|
||||
name=name,
|
||||
embedding_function=None,
|
||||
metadata=collection_metadata,
|
||||
)
|
||||
|
@@ -4,9 +4,9 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
@@ -73,47 +73,63 @@ logger = logging.getLogger(__name__)
|
||||
],
|
||||
description=_("Elasticsearch vector config."),
|
||||
)
|
||||
class ElasticsearchVectorConfig(VectorStoreConfig):
|
||||
@dataclass
|
||||
class ElasticsearchStoreConfig(VectorStoreConfig):
|
||||
"""Elasticsearch vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
__type__ = "ElasticSearch"
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
uri: str = Field(
|
||||
uri: str = field(
|
||||
default="localhost",
|
||||
description="The uri of elasticsearch store, if not set, will use the default "
|
||||
"uri.",
|
||||
metadata={
|
||||
"description": "The uri of elasticsearch store, if not set, "
|
||||
"will use the default uri."
|
||||
},
|
||||
)
|
||||
port: str = Field(
|
||||
port: str = field(
|
||||
default="9200",
|
||||
description="The port of elasticsearch store, if not set, will use the default "
|
||||
"port.",
|
||||
metadata={
|
||||
"description": "The port of elasticsearch store, if not set, will use the "
|
||||
"default port."
|
||||
},
|
||||
)
|
||||
|
||||
alias: str = Field(
|
||||
alias: str = field(
|
||||
default="default",
|
||||
description="The alias of elasticsearch store, if not set, will use the "
|
||||
"default "
|
||||
"alias.",
|
||||
metadata={
|
||||
"description": "The alias of elasticsearch store, if not set, will use the "
|
||||
"default alias."
|
||||
},
|
||||
)
|
||||
index_name: str = Field(
|
||||
index_name: str = field(
|
||||
default="index_name_test",
|
||||
description="The index name of elasticsearch store, if not set, will use the "
|
||||
"default index name.",
|
||||
metadata={
|
||||
"description": "The index name of elasticsearch store, if not set, will"
|
||||
" use "
|
||||
"the default index name."
|
||||
},
|
||||
)
|
||||
metadata_field: str = Field(
|
||||
metadata_field: str = field(
|
||||
default="metadata",
|
||||
description="The metadata field of elasticsearch store, if not set, will use "
|
||||
"the default metadata field.",
|
||||
metadata={
|
||||
"description": "The metadata field of elasticsearch store, "
|
||||
"if not set, will "
|
||||
"use the default metadata field."
|
||||
},
|
||||
)
|
||||
secure: str = Field(
|
||||
secure: str = field(
|
||||
default="",
|
||||
description="The secure of elasticsearch store, if not set, will use the "
|
||||
"default secure.",
|
||||
metadata={
|
||||
"description": "The secure of elasticsearch store, i"
|
||||
"f not set, will use the "
|
||||
"default secure."
|
||||
},
|
||||
)
|
||||
|
||||
def create_store(self, **kwargs) -> "ElasticStore":
|
||||
"""Create Elastic store."""
|
||||
return ElasticStore(vector_store_config=self, **kwargs)
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Elastic Vector Store"),
|
||||
@@ -124,7 +140,7 @@ class ElasticsearchVectorConfig(VectorStoreConfig):
|
||||
Parameter.build_from(
|
||||
_("Elastic Config"),
|
||||
"vector_store_config",
|
||||
ElasticsearchVectorConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
description=_("the elastic config of vector store."),
|
||||
optional=True,
|
||||
default=None,
|
||||
@@ -134,11 +150,16 @@ class ElasticsearchVectorConfig(VectorStoreConfig):
|
||||
class ElasticStore(VectorStoreBase):
|
||||
"""Elasticsearch vector store."""
|
||||
|
||||
def __init__(self, vector_store_config: ElasticsearchVectorConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_config: ElasticsearchStoreConfig,
|
||||
name: Optional[str],
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
) -> None:
|
||||
"""Create a ElasticsearchStore instance.
|
||||
|
||||
Args:
|
||||
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
|
||||
vector_store_config (ElasticsearchStoreConfig): ElasticsearchStore config.
|
||||
"""
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
@@ -158,21 +179,19 @@ class ElasticStore(VectorStoreBase):
|
||||
"ELASTICSEARCH_PASSWORD"
|
||||
) or elasticsearch_vector_config.get("password")
|
||||
|
||||
self.collection_name = (
|
||||
elasticsearch_vector_config.get("name") or vector_store_config.name
|
||||
)
|
||||
self.collection_name = name
|
||||
# name to hex
|
||||
if string_utils.contains_chinese(self.collection_name):
|
||||
bytes_str = self.collection_name.encode("utf-8")
|
||||
hex_str = bytes_str.hex()
|
||||
self.collection_name = hex_str
|
||||
if vector_store_config.embedding_fn is None:
|
||||
if embedding_fn is None:
|
||||
# Perform runtime checks on self.embedding to
|
||||
# ensure it has been correctly set and loaded
|
||||
raise ValueError("embedding_fn is required for ElasticSearchStore")
|
||||
# to lower case
|
||||
self.index_name = self.collection_name.lower()
|
||||
self.embedding: Embeddings = vector_store_config.embedding_fn
|
||||
self.embedding: Embeddings = embedding_fn
|
||||
self.fields: List = []
|
||||
|
||||
if (self.username is None) != (self.password is None):
|
||||
@@ -251,7 +270,7 @@ class ElasticStore(VectorStoreBase):
|
||||
except Exception as e:
|
||||
logger.error(f"ElasticSearch connection failed: {e}")
|
||||
|
||||
def get_config(self) -> ElasticsearchVectorConfig:
|
||||
def get_config(self) -> ElasticsearchStoreConfig:
|
||||
"""Get the vector store config."""
|
||||
return self._vector_store_config
|
||||
|
||||
|
@@ -1,46 +0,0 @@
|
||||
"""Vector store factory."""
|
||||
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt_ext.storage import __vector_store__ as vector_store_list
|
||||
from dbgpt_ext.storage import _select_rag_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStoreFactory:
|
||||
"""Factory for vector store."""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
vector_store_type: str, vector_space_name: str, vector_store_configure=None
|
||||
) -> VectorStoreBase:
|
||||
"""Create a VectorStore instance.
|
||||
|
||||
Args:
|
||||
- vector_store_type: vector store type Chroma, Milvus, etc.
|
||||
- vector_store_config: vector store config
|
||||
"""
|
||||
store_cls, cfg_cls = VectorStoreFactory.__find_type(vector_store_type)
|
||||
|
||||
try:
|
||||
config = cfg_cls()
|
||||
if vector_store_configure:
|
||||
vector_store_configure(vector_space_name, config)
|
||||
return store_cls(config)
|
||||
except Exception as e:
|
||||
logger.error("create vector store failed: %s", e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
|
||||
for t in vector_store_list:
|
||||
if t.lower() == vector_store_type.lower():
|
||||
store_cls, cfg_cls = _select_rag_storage(t)
|
||||
if issubclass(store_cls, VectorStoreBase) and issubclass(
|
||||
cfg_cls, VectorStoreConfig
|
||||
):
|
||||
return store_cls, cfg_cls
|
||||
raise Exception(f"Vector store {vector_store_type} not supported")
|
@@ -5,9 +5,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
@@ -94,51 +94,82 @@ logger = logging.getLogger(__name__)
|
||||
],
|
||||
description=_("Milvus vector config."),
|
||||
)
|
||||
@dataclass
|
||||
class MilvusVectorConfig(VectorStoreConfig):
|
||||
"""Milvus vector store config."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
__type__ = "Milvus"
|
||||
|
||||
uri: Optional[str] = Field(
|
||||
uri: str = field(
|
||||
default=None,
|
||||
description="The uri of milvus store, if not set, will use the default uri.",
|
||||
metadata={
|
||||
"help": _("The uri of milvus store, if not set, will use the default uri.")
|
||||
},
|
||||
)
|
||||
port: str = Field(
|
||||
port: str = field(
|
||||
default="19530",
|
||||
description="The port of milvus store, if not set, will use the default port.",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The port of milvus store, if not set, will use the default port."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
alias: str = Field(
|
||||
alias: str = field(
|
||||
default="default",
|
||||
description="The alias of milvus store, if not set, will use the default "
|
||||
"alias.",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The alias of milvus store, if not set, will use the default alias."
|
||||
)
|
||||
},
|
||||
)
|
||||
primary_field: str = Field(
|
||||
primary_field: str = field(
|
||||
default="pk_id",
|
||||
description="The primary field of milvus store, if not set, will use the "
|
||||
"default primary field.",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The primary field of milvus store, i"
|
||||
"f not set, will use the default primary field."
|
||||
)
|
||||
},
|
||||
)
|
||||
text_field: str = Field(
|
||||
text_field: str = field(
|
||||
default="content",
|
||||
description="The text field of milvus store, if not set, will use the default "
|
||||
"text field.",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The text field of milvus store, if not set, will use the "
|
||||
"default text field."
|
||||
)
|
||||
},
|
||||
)
|
||||
embedding_field: str = Field(
|
||||
embedding_field: str = field(
|
||||
default="vector",
|
||||
description="The embedding field of milvus store, if not set, will use the "
|
||||
"default embedding field.",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The embedding field of milvus store, if not set, will use the "
|
||||
"default embedding field."
|
||||
)
|
||||
},
|
||||
)
|
||||
metadata_field: str = Field(
|
||||
metadata_field: str = field(
|
||||
default="metadata",
|
||||
description="The metadata field of milvus store, if not set, will use the "
|
||||
"default metadata field.",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The metadata field of milvus store, if not set, will use the "
|
||||
"default metadata field."
|
||||
)
|
||||
},
|
||||
)
|
||||
secure: str = Field(
|
||||
secure: str = field(
|
||||
default="",
|
||||
description="The secure of milvus store, if not set, will use the default "
|
||||
"secure.",
|
||||
metadata={
|
||||
"help": _("The secure of milvus store, if not set, will use the default ")
|
||||
},
|
||||
)
|
||||
|
||||
def create_store(self, **kwargs) -> "MilvusStore":
|
||||
"""Create Milvus Store."""
|
||||
return MilvusStore(vector_store_config=self, **kwargs)
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Milvus Vector Store"),
|
||||
@@ -159,7 +190,12 @@ class MilvusVectorConfig(VectorStoreConfig):
|
||||
class MilvusStore(VectorStoreBase):
|
||||
"""Milvus vector store."""
|
||||
|
||||
def __init__(self, vector_store_config: MilvusVectorConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_config: MilvusVectorConfig,
|
||||
name: Optional[str],
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
) -> None:
|
||||
"""Create a MilvusStore instance.
|
||||
|
||||
Args:
|
||||
@@ -190,18 +226,16 @@ class MilvusStore(VectorStoreBase):
|
||||
)
|
||||
self.secure = milvus_vector_config.get("secure") or os.getenv("MILVUS_SECURE")
|
||||
|
||||
self.collection_name = (
|
||||
milvus_vector_config.get("name") or vector_store_config.name
|
||||
)
|
||||
self.collection_name = name
|
||||
if string_utils.contains_chinese(self.collection_name):
|
||||
bytes_str = self.collection_name.encode("utf-8")
|
||||
hex_str = bytes_str.hex()
|
||||
self.collection_name = hex_str
|
||||
if vector_store_config.embedding_fn is None:
|
||||
if embedding_fn is None:
|
||||
# Perform runtime checks on self.embedding to
|
||||
# ensure it has been correctly set and loaded
|
||||
raise ValueError("embedding_fn is required for MilvusStore")
|
||||
self.embedding: Embeddings = vector_store_config.embedding_fn
|
||||
self.embedding: Embeddings = embedding_fn
|
||||
self.fields: List = []
|
||||
self.alias = milvus_vector_config.get("alias") or "default"
|
||||
|
||||
|
@@ -5,14 +5,14 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sqlalchemy import JSON, Column, String, Table, func, text
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
_COMMON_PARAMETERS,
|
||||
@@ -122,35 +122,48 @@ def _normalize(vector: List[float]) -> List[float]:
|
||||
],
|
||||
description="OceanBase vector store config.",
|
||||
)
|
||||
@dataclass
|
||||
class OceanBaseConfig(VectorStoreConfig):
|
||||
"""OceanBase vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
__type__ = "OceanBase"
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
ob_host: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The host of oceanbase, if not set, will use the default host."
|
||||
},
|
||||
)
|
||||
ob_port: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The port of oceanbase, if not set, will use the default port."
|
||||
},
|
||||
)
|
||||
ob_user: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The user of oceanbase, if not set, will use the default user."
|
||||
},
|
||||
)
|
||||
ob_password: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The password of oceanbase, if not set, "
|
||||
"will use the default password"
|
||||
},
|
||||
)
|
||||
ob_database: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The database for vector tables, if not set, "
|
||||
"will use the default database."
|
||||
},
|
||||
)
|
||||
|
||||
"""OceanBase config"""
|
||||
ob_host: Optional[str] = Field(
|
||||
default=None,
|
||||
description="oceanbase host",
|
||||
)
|
||||
ob_port: Optional[int] = Field(
|
||||
default=None,
|
||||
description="oceanbase port",
|
||||
)
|
||||
ob_user: Optional[str] = Field(
|
||||
default=None,
|
||||
description="user to login",
|
||||
)
|
||||
ob_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="password to login",
|
||||
)
|
||||
ob_database: Optional[str] = Field(
|
||||
default=None,
|
||||
description="database for vector tables",
|
||||
)
|
||||
def create_store(self, **kwargs) -> "OceanBaseStore":
|
||||
"""Create OceanBase store."""
|
||||
return OceanBaseStore(vector_store_config=self, **kwargs)
|
||||
|
||||
|
||||
@register_resource(
|
||||
@@ -172,7 +185,12 @@ class OceanBaseConfig(VectorStoreConfig):
|
||||
class OceanBaseStore(VectorStoreBase):
|
||||
"""OceanBase vector store."""
|
||||
|
||||
def __init__(self, vector_store_config: OceanBaseConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_config: OceanBaseConfig,
|
||||
name: Optional[str],
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
) -> None:
|
||||
"""Create a OceanBaseStore instance."""
|
||||
try:
|
||||
from pyobvector import ObVecClient # type: ignore
|
||||
@@ -188,8 +206,8 @@ class OceanBaseStore(VectorStoreBase):
|
||||
super().__init__()
|
||||
|
||||
self._vector_store_config = vector_store_config
|
||||
self.embedding_function = vector_store_config.embedding_fn
|
||||
self.table_name = vector_store_config.name
|
||||
self.embedding_function = embedding_fn
|
||||
self.table_name = name
|
||||
|
||||
vector_store_config_map = vector_store_config.to_dict()
|
||||
OB_HOST = str(
|
||||
|
@@ -1,10 +1,10 @@
|
||||
"""Postgres vector store."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
_COMMON_PARAMETERS,
|
||||
@@ -37,17 +37,22 @@ logger = logging.getLogger(__name__)
|
||||
],
|
||||
description="PG vector config.",
|
||||
)
|
||||
@dataclass
|
||||
class PGVectorConfig(VectorStoreConfig):
|
||||
"""PG vector store config."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
connection_string: str = Field(
|
||||
connection_string: str = field(
|
||||
default=None,
|
||||
description="the connection string of vector store, if not set, will use the "
|
||||
"default connection string.",
|
||||
metadata={
|
||||
"description": "the connection string of vector store, "
|
||||
"if not set, will use the default connection string."
|
||||
},
|
||||
)
|
||||
|
||||
def create_store(self, **kwargs) -> "PGVectorStore":
|
||||
"""Create Milvus Store."""
|
||||
return PGVectorStore(vector_store_config=self, **kwargs)
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("PG Vector Store"),
|
||||
@@ -71,7 +76,12 @@ class PGVectorStore(VectorStoreBase):
|
||||
To use this, you should have the ``pgvector`` python package installed.
|
||||
"""
|
||||
|
||||
def __init__(self, vector_store_config: PGVectorConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_config: PGVectorConfig,
|
||||
name: Optional[str],
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
) -> None:
|
||||
"""Create a PGVectorStore instance."""
|
||||
try:
|
||||
from langchain.vectorstores import PGVector # mypy: ignore
|
||||
@@ -83,8 +93,8 @@ class PGVectorStore(VectorStoreBase):
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
self.connection_string = vector_store_config.connection_string
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
self.collection_name = vector_store_config.name
|
||||
self.embeddings = embedding_fn
|
||||
self.collection_name = name
|
||||
|
||||
self.vector_store_client = PGVector(
|
||||
embedding_function=self.embeddings, # type: ignore
|
||||
|
@@ -2,10 +2,10 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
_COMMON_PARAMETERS,
|
||||
@@ -45,20 +45,28 @@ logger = logging.getLogger(__name__)
|
||||
),
|
||||
],
|
||||
)
|
||||
@dataclass
|
||||
class WeaviateVectorConfig(VectorStoreConfig):
|
||||
"""Weaviate vector store config."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
weaviate_url: str = Field(
|
||||
weaviate_url: str = field(
|
||||
default=os.getenv("WEAVIATE_URL", None),
|
||||
description="weaviate url address, if not set, will use the default url.",
|
||||
metadata={
|
||||
"description": "weaviate url address, if not set, "
|
||||
"will use the default url.",
|
||||
},
|
||||
)
|
||||
persist_path: str = Field(
|
||||
persist_path: str = field(
|
||||
default=os.getenv("WEAVIATE_PERSIST_PATH", None),
|
||||
description="weaviate persist path.",
|
||||
metadata={
|
||||
"description": "weaviate persist path.",
|
||||
},
|
||||
)
|
||||
|
||||
def create_store(self, **kwargs) -> "WeaviateStore":
|
||||
"""Create Weaviate Store store."""
|
||||
return WeaviateStore(vector_store_config=self, **kwargs)
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Weaviate Vector Store"),
|
||||
@@ -79,7 +87,12 @@ class WeaviateVectorConfig(VectorStoreConfig):
|
||||
class WeaviateStore(VectorStoreBase):
|
||||
"""Weaviate database."""
|
||||
|
||||
def __init__(self, vector_store_config: WeaviateVectorConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_config: WeaviateVectorConfig,
|
||||
name: Optional[str],
|
||||
embedding_fn: Optional[Embeddings] = None,
|
||||
) -> None:
|
||||
"""Initialize with Weaviate client."""
|
||||
try:
|
||||
import weaviate
|
||||
@@ -92,10 +105,10 @@ class WeaviateStore(VectorStoreBase):
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
self.weaviate_url = vector_store_config.weaviate_url
|
||||
self.embedding = vector_store_config.embedding_fn
|
||||
self.vector_name = vector_store_config.name
|
||||
self.embedding = embedding_fn
|
||||
self.vector_name = name
|
||||
self.persist_dir = os.path.join(
|
||||
vector_store_config.persist_path, vector_store_config.name + ".vectordb"
|
||||
vector_store_config.persist_path, name + ".vectordb"
|
||||
)
|
||||
|
||||
self.vector_store_client = weaviate.Client(self.weaviate_url)
|
||||
|
Reference in New Issue
Block a user