feat: ES VectorStore (#1500)
Co-authored-by: aries_ckt <916701291@qq.com>

parent 8b88f7e11c · commit db4d318a5f
@ -163,6 +163,13 @@ VECTOR_STORE_TYPE=Chroma
 #VECTOR_STORE_TYPE=Weaviate
 #WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network
 
+## ElasticSearch vector db config
+#VECTOR_STORE_TYPE=ElasticSearch
+ELASTICSEARCH_URL=127.0.0.1
+ELASTICSEARCH_PORT=9200
+ELASTICSEARCH_USERNAME=elastic
+ELASTICSEARCH_PASSWORD=i=+iLw9y0Jduq86XTi6W
+
 #*******************************************************************#
 #**             WebServer Language Support                        **#
 #*******************************************************************#
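To try the new backend, uncomment VECTOR_STORE_TYPE=ElasticSearch and point the four ELASTICSEARCH_* values at a reachable cluster. A quick reachability check before booting DB-GPT (a sketch; assumes the `elasticsearch` client package and the example credentials above):

from elasticsearch import Elasticsearch

# Values mirror the .env entries above; swap in your own host and credentials.
es = Elasticsearch("http://127.0.0.1:9200", basic_auth=("elastic", "i=+iLw9y0Jduq86XTi6W"))
print(es.ping())  # True when the cluster is reachable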
@ -59,6 +59,9 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 follow_imports = skip
 
+[mypy-jieba.*]
+ignore_missing_imports = True
+
 # Storage
 [mypy-msgpack.*]
 ignore_missing_imports = True
@ -72,6 +75,9 @@ ignore_missing_imports = True
 [mypy-pymilvus.*]
 ignore_missing_imports = True
 
+[mypy-elasticsearch.*]
+ignore_missing_imports = True
+
 [mypy-cryptography.*]
 ignore_missing_imports = True
 
@ -207,7 +207,7 @@ class Config(metaclass=Singleton):
             os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
         )
 
-        ###dbgpt meta info database connection configuration
+        ### dbgpt meta info database connection configuration
         self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
         self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
         self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
@ -247,6 +247,11 @@ class Config(metaclass=Singleton):
         self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
         self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
         self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
+        # Elasticsearch Vector Configuration
+        self.ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "127.0.0.1")
+        self.ELASTICSEARCH_PORT = os.getenv("ELASTICSEARCH_PORT", "9200")
+        self.ELASTICSEARCH_USERNAME = os.getenv("ELASTICSEARCH_USERNAME", None)
+        self.ELASTICSEARCH_PASSWORD = os.getenv("ELASTICSEARCH_PASSWORD", None)
 
         ## OceanBase Configuration
         self.OB_HOST = os.getenv("OB_HOST", "127.0.0.1")
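These values flow through the `Config` singleton at runtime. A minimal sketch of how downstream code reads them (assuming the `Config` import path used elsewhere in DB-GPT; the attribute names come straight from this hunk):

from dbgpt._private.config import Config

CFG = Config()  # singleton; values are resolved from the environment
es_endpoint = f"http://{CFG.ELASTICSEARCH_URL}:{CFG.ELASTICSEARCH_PORT}"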
@ -482,7 +482,13 @@ class KnowledgeService:
             raise Exception(f"there are no or more than one document called {doc_name}")
         vector_ids = documents[0].vector_ids
         if vector_ids is not None:
-            config = VectorStoreConfig(name=space_name)
+            embedding_factory = CFG.SYSTEM_APP.get_component(
+                "embedding_factory", EmbeddingFactory
+            )
+            embedding_fn = embedding_factory.create(
+                model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
+            )
+            config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
             vector_store_connector = VectorStoreConnector(
                 vector_store_type=CFG.VECTOR_STORE_TYPE,
                 vector_store_config=config,
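This hunk wires a concrete embedding function into the store config before the connector is built; `ElasticStore` (added below) raises a `ValueError` when `embedding_fn` is missing. A hedged sketch of the resulting call shape, with the "ElasticSearch" type string matching the name registered in `vector_store/__init__.py` below:

config = VectorStoreConfig(name="my_space", embedding_fn=embedding_fn)
vector_store_connector = VectorStoreConnector(
    vector_store_type="ElasticSearch",  # or CFG.VECTOR_STORE_TYPE, as in the diff
    vector_store_config=config,
)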
@ -32,6 +32,12 @@ def _import_oceanbase() -> Any:
     return OceanBaseStore
 
 
+def _import_elastic() -> Any:
+    from dbgpt.storage.vector_store.elastic_store import ElasticStore
+
+    return ElasticStore
+
+
 def __getattr__(name: str) -> Any:
     if name == "Chroma":
         return _import_chroma()
@ -43,8 +49,10 @@ def __getattr__(name: str) -> Any:
         return _import_pgvector()
     elif name == "OceanBase":
         return _import_oceanbase()
+    elif name == "ElasticSearch":
+        return _import_elastic()
     else:
         raise AttributeError(f"Could not find: {name}")
 
 
-__all__ = ["Chroma", "Milvus", "Weaviate", "OceanBase", "PGVector"]
+__all__ = ["Chroma", "Milvus", "Weaviate", "OceanBase", "PGVector", "ElasticSearch"]
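The module-level `__getattr__` keeps the import lazy: the `elasticsearch` and `langchain` dependencies are only imported when the class is first requested. For example:

# Resolved lazily through __getattr__; returns the ElasticStore class
from dbgpt.storage.vector_store import ElasticSearch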
dbgpt/storage/vector_store/elastic_store.py (new file, 398 lines)
@ -0,0 +1,398 @@
"""Elasticsearch vector store."""

from __future__ import annotations

import logging
import os
from typing import List, Optional

from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
    _COMMON_PARAMETERS,
    VectorStoreBase,
    VectorStoreConfig,
)
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util import string_utils
from dbgpt.util.i18n_utils import _

logger = logging.getLogger(__name__)


@register_resource(
    _("ElasticSearch Vector Store"),
    "elasticsearch_vector_store",
    category=ResourceCategory.VECTOR_STORE,
    parameters=[
        *_COMMON_PARAMETERS,
        Parameter.build_from(
            _("Uri"),
            "uri",
            str,
            description=_(
                "The uri of elasticsearch store, if not set, will use the default "
                "uri."
            ),
            optional=True,
            default="localhost",
        ),
        Parameter.build_from(
            _("Port"),
            "port",
            str,
            description=_(
                "The port of elasticsearch store, if not set, will use the default "
                "port."
            ),
            optional=True,
            default="9200",
        ),
        Parameter.build_from(
            _("Alias"),
            "alias",
            str,
            description=_(
                "The alias of elasticsearch store, if not set, will use the default "
                "alias."
            ),
            optional=True,
            default="default",
        ),
        Parameter.build_from(
            _("Index Name"),
            "index_name",
            str,
            description=_(
                "The index name of elasticsearch store, if not set, will use the "
                "default index name."
            ),
            optional=True,
            default="index_name_test",
        ),
    ],
    description=_("Elasticsearch vector store."),
)
class ElasticsearchVectorConfig(VectorStoreConfig):
    """Elasticsearch vector store config."""

    class Config:
        """Config for BaseModel."""

        arbitrary_types_allowed = True

    uri: str = Field(
        default="localhost",
        description="The uri of elasticsearch store, if not set, will use the default "
        "uri.",
    )
    port: str = Field(
        default="9200",
        description="The port of elasticsearch store, if not set, will use the default "
        "port.",
    )

    alias: str = Field(
        default="default",
        description="The alias of elasticsearch store, if not set, will use the "
        "default alias.",
    )
    index_name: str = Field(
        default="index_name_test",
        description="The index name of elasticsearch store, if not set, will use the "
        "default index name.",
    )
    metadata_field: str = Field(
        default="metadata",
        description="The metadata field of elasticsearch store, if not set, will use "
        "the default metadata field.",
    )
    secure: str = Field(
        default="",
        description="The secure of elasticsearch store, if not set, will use the "
        "default secure.",
    )


class ElasticStore(VectorStoreBase):
    """Elasticsearch vector store."""

    def __init__(self, vector_store_config: ElasticsearchVectorConfig) -> None:
        """Create an ElasticsearchStore instance.

        Args:
            vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
        """
        connect_kwargs = {}
        elasticsearch_vector_config = vector_store_config.dict()
        self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
            "ELASTICSEARCH_URL", "localhost"
        )
        self.port = elasticsearch_vector_config.get("port") or os.getenv(
            "ELASTICSEARCH_PORT", "9200"
        )
        self.username = elasticsearch_vector_config.get("username") or os.getenv(
            "ELASTICSEARCH_USERNAME"
        )
        self.password = elasticsearch_vector_config.get("password") or os.getenv(
            "ELASTICSEARCH_PASSWORD"
        )

        self.collection_name = (
            elasticsearch_vector_config.get("name") or vector_store_config.name
        )
        # name to hex
        if string_utils.is_all_chinese(self.collection_name):
            bytes_str = self.collection_name.encode("utf-8")
            hex_str = bytes_str.hex()
            self.collection_name = hex_str
        if vector_store_config.embedding_fn is None:
            # Perform runtime checks on self.embedding to
            # ensure it has been correctly set and loaded
            raise ValueError("embedding_fn is required for ElasticSearchStore")
        # to lower case
        self.index_name = self.collection_name.lower()
        self.embedding: Embeddings = vector_store_config.embedding_fn
        self.fields: List = []

        if (self.username is None) != (self.password is None):
            raise ValueError(
                "Both username and password must be set to use authentication for "
                "ElasticSearch"
            )

        if self.username:
            connect_kwargs["username"] = self.username
            connect_kwargs["password"] = self.password

        # english index settings
        self.index_settings = {
            "settings": {
                "number_of_shards": 1,
                "number_of_replicas": 0,  # replica number
            }
        }
""""""
|
||||
        try:
            from elasticsearch import Elasticsearch
            from langchain.vectorstores.elasticsearch import ElasticsearchStore
        except ImportError:
            raise ValueError(
                "Could not import langchain and elasticsearch python package. "
                "Please install it with `pip install langchain` and "
                "`pip install elasticsearch`."
            )
        try:
            if self.username != "" and self.password != "":
                self.es_client_python = Elasticsearch(
                    f"http://{self.uri}:{self.port}",
                    basic_auth=(self.username, self.password),
                )
                # create es index
                if not self.vector_name_exists():
                    self.es_client_python.indices.create(
                        index=self.index_name, body=self.index_settings
                    )
            else:
                logger.warning("ElasticSearch not set username and password")
                self.es_client_python = Elasticsearch(f"http://{self.uri}:{self.port}")
                if not self.vector_name_exists():
                    self.es_client_python.indices.create(
                        index=self.index_name, body=self.index_settings
                    )
        except ConnectionError:
            logger.error("ElasticSearch connection failed")
        except Exception as e:
            logger.error(f"ElasticSearch connection failed : {e}")

        # create es index
        try:
            if self.username != "" and self.password != "":
                self.db_init = ElasticsearchStore(
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    query_field="context",
                    vector_query_field="dense_vector",
                    embedding=self.embedding,  # type: ignore
                    es_user=self.username,
                    es_password=self.password,
                )
            else:
                logger.warning("ElasticSearch not set username and password")
                self.db_init = ElasticsearchStore(
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    query_field="context",
                    vector_query_field="dense_vector",
                    embedding=self.embedding,  # type: ignore
                )
        except ConnectionError:
            logger.error("ElasticSearch connection failed")
        except Exception as e:
            logger.error(f"ElasticSearch connection failed: {e}")

    def load_document(
        self,
        chunks: List[Chunk],
    ) -> List[str]:
        """Add text data into ElasticSearch."""
        logger.info("ElasticStore load document")
        try:
            from langchain.vectorstores.elasticsearch import ElasticsearchStore
        except ImportError:
            raise ValueError(
                "Could not import langchain python package. "
                "Please install it with `pip install langchain` and "
                "`pip install elasticsearch`."
            )
        try:
            texts = [chunk.content for chunk in chunks]
            metadatas = [chunk.metadata for chunk in chunks]
            ids = [chunk.chunk_id for chunk in chunks]
            if self.username != "" and self.password != "":
                self.db = ElasticsearchStore.from_texts(
                    texts=texts,
                    embedding=self.embedding,  # type: ignore
                    metadatas=metadatas,
                    ids=ids,
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    # Defaults to COSINE. Can be one of COSINE,
                    # EUCLIDEAN_DISTANCE, or DOT_PRODUCT.
                    distance_strategy="COSINE",
                    # Name of the field to store the texts in.
                    query_field="context",
                    # Optional. Name of the field to store the embedding vectors in.
                    vector_query_field="dense_vector",
                    # verify_certs=False,
                    # strategy: Optional. Retrieval strategy to use when searching
                    # the index. Defaults to ApproxRetrievalStrategy.
                    # Can be one of ExactRetrievalStrategy, ApproxRetrievalStrategy,
                    # or SparseRetrievalStrategy.
                    es_user=self.username,
                    es_password=self.password,
                )  # type: ignore
                logger.info("Elasticsearch save success.......")
                return ids
            else:
                self.db = ElasticsearchStore.from_texts(
                    texts=texts,
                    embedding=self.embedding,  # type: ignore
                    metadatas=metadatas,
                    ids=ids,
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    distance_strategy="COSINE",
                    query_field="context",
                    vector_query_field="dense_vector",
                    # verify_certs=False,
                )  # type: ignore
                return ids
        except ConnectionError as ce:
            logger.error(f"ElasticSearch connect failed {ce}")
        except Exception as e:
            logger.error(f"ElasticSearch load_document failed : {e}")
        return []

    def delete_by_ids(self, ids):
        """Delete vector by ids."""
        logger.info(f"begin delete elasticsearch len ids: {len(ids)}")
        ids = ids.split(",")
        try:
            self.db_init.delete(ids=ids)
            self.es_client_python.indices.refresh(index=self.index_name)
        except Exception as e:
            logger.error(f"ElasticSearch delete_by_ids failed : {e}")

    def similar_search(
        self,
        text: str,
        topk: int,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Perform a search on a query string and return results."""
        info_docs = self._search(query=text, topk=topk, filters=filters)
        return info_docs

    def similar_search_with_scores(
        self, text, topk, score_threshold, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Perform a search on a query string and return results with score.

        For more information about the search parameters, take a look at the
        ElasticSearch documentation found here: https://www.elastic.co/.

        Args:
            text (str): The query text.
            topk (int): The number of similar documents to return.
            score_threshold (float): Optional, a floating point value between 0 and 1.
            filters (Optional[MetadataFilters]): Optional, metadata filters.
        Returns:
            List[Chunk]: Result doc and score.
        """
        query = text
        info_docs = self._search(query=query, topk=topk, filters=filters)
        docs_and_scores = [
            chunk for chunk in info_docs if chunk.score >= score_threshold
        ]
        if len(docs_and_scores) == 0:
            logger.warning(
                "No relevant docs were retrieved using the relevance score"
                f" threshold {score_threshold}"
            )
        return docs_and_scores

    def _search(
        self, query: str, topk: int, filters: Optional[MetadataFilters] = None, **kwargs
    ) -> List[Chunk]:
        """Search similar documents.

        Args:
            query: query text
            topk: return docs nums. Defaults to 4.
            filters: metadata filters.
        Return:
            List[Chunk]: list of chunks
        """
        jieba_tokenize = kwargs.pop("jieba_tokenize", None)
        if jieba_tokenize:
            try:
                import jieba
                import jieba.analyse
            except ImportError:
                raise ValueError("Please install it with `pip install jieba`.")
            query_list = jieba.analyse.textrank(query, topK=20, withWeight=False)
            query = " ".join(query_list)
        body = {"query": {"match": {"context": query}}}
        search_results = self.es_client_python.search(
            index=self.index_name, body=body, size=topk
        )
        search_results = search_results["hits"]["hits"]

        if not search_results:
            logger.warning("""No ElasticSearch results found.""")
            return []
        info_docs = []
        for result in search_results:
            doc_id = result["_id"]
            source = result["_source"]
            context = source["context"]
            metadata = source["metadata"]
            score = result["_score"]
            doc_with_score = Chunk(
                content=context, metadata=metadata, score=score, chunk_id=doc_id
            )
            info_docs.append(doc_with_score)
        return info_docs

    def vector_name_exists(self):
        """Whether vector name exists."""
        return self.es_client_python.indices.exists(index=self.index_name)

    def delete_vector_name(self, vector_name: str):
        """Delete vector name/index_name."""
        if self.es_client_python.indices.exists(index=self.index_name):
            self.es_client_python.indices.delete(index=self.index_name)
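Taken together, the new store can be exercised end to end. A hedged sketch (`embeddings` stands in for whatever `dbgpt.core.Embeddings` implementation the deployment configures; server values mirror the `.env` defaults above):

from dbgpt.core import Chunk
from dbgpt.storage.vector_store.elastic_store import (
    ElasticsearchVectorConfig,
    ElasticStore,
)

config = ElasticsearchVectorConfig(
    name="demo_space",
    uri="127.0.0.1",
    port="9200",
    embedding_fn=embeddings,  # assumption: any configured Embeddings implementation
)
store = ElasticStore(config)
ids = store.load_document([Chunk(content="hello elasticsearch", metadata={"source": "demo"})])
hits = store.similar_search("hello", topk=4)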