feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>

Author: Aries-ckt
Committed by: GitHub
Date: 2024-06-13 13:49:17 +08:00
Parent: 162e2c9b1c
Commit: 58d08780d6
86 changed files with 948 additions and 440 deletions

@@ -0,0 +1 @@
"""Full Text Store Module."""

@@ -0,0 +1,69 @@
"""Full text store base class."""
import logging
from abc import abstractmethod
from concurrent.futures import Executor
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import blocking_func_to_async
logger = logging.getLogger(__name__)
class FullTextStoreBase(IndexStoreBase):
"""Graph store base class."""
def __init__(self, executor: Optional[Executor] = None):
"""Initialize vector store."""
super().__init__(executor)
@abstractmethod
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in index database.
Args:
chunks(List[Chunk]): document chunks.
Return:
List[str]: chunk ids.
"""
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Async load document in index database.
Args:
chunks(List[Chunk]): document chunks.
Return:
List[str]: chunk ids.
"""
return await blocking_func_to_async(self._executor, self.load_document, chunks)
@abstractmethod
def similar_search_with_scores(
self,
text,
topk,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Similar search with scores in index database.
Args:
text(str): The query text.
topk(int): The number of similar documents to return.
score_threshold(int): score_threshold: Optional, a floating point value
between 0 to 1
filters(Optional[MetadataFilters]): metadata filters.
"""
@abstractmethod
def delete_by_ids(self, ids: str) -> List[str]:
"""Delete docs.
Args:
ids(str): The vector ids to delete, separated by comma.
"""
def delete_vector_name(self, index_name: str):
"""Delete name."""

@@ -0,0 +1,227 @@
"""Elasticsearch document store."""
import json
import os
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.index.base import logger
from dbgpt.storage.full_text.base import FullTextStoreBase
from dbgpt.storage.vector_store.elastic_store import ElasticsearchVectorConfig
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util import string_utils
from dbgpt.util.executor_utils import blocking_func_to_async


class ElasticDocumentConfig(ElasticsearchVectorConfig):
    """Elasticsearch document store config."""

    k1: Optional[float] = 2.0
    b: Optional[float] = 0.75


class ElasticDocumentStore(FullTextStoreBase):
    """Elasticsearch index store."""

    def __init__(
        self, es_config: ElasticDocumentConfig, executor: Optional[Executor] = None
    ):
        """Init elasticsearch index store.

        Refer to https://www.elastic.co/guide/en/elasticsearch/reference/8.9/index-modules-similarity.html.

        BM25 is a TF/IDF based similarity that has built-in tf normalization and is
        supposed to work better for short fields (like names). See Okapi_BM25 for
        more details. This similarity has two tunable options, k1 and b, described
        below.
        """
        super().__init__()
        from elasticsearch import Elasticsearch

        self._es_config = es_config
        self._es_url = es_config.uri or os.getenv("ELASTICSEARCH_URL", "localhost")
        self._es_port = es_config.port or os.getenv("ELASTICSEARCH_PORT", "9200")
        self._es_username = es_config.user or os.getenv("ELASTICSEARCH_USER", "elastic")
        self._es_password = es_config.password or os.getenv(
            "ELASTICSEARCH_PASSWORD", "dbgpt"
        )
        self._index_name = es_config.name.lower()
        if string_utils.contains_chinese(es_config.name):
            bytes_str = es_config.name.encode("utf-8")
            hex_str = bytes_str.hex()
            self._index_name = "dbgpt_" + hex_str
        # k1 (Optional[float]): Controls non-linear term frequency normalization
        # (saturation). The default value is 2.0.
        self._k1 = es_config.k1 or 2.0
        # b (Optional[float]): Controls to what degree document length normalizes
        # tf values. The default value is 0.75.
        self._b = es_config.b or 0.75
        if self._es_username and self._es_password:
            self._es_client = Elasticsearch(
                hosts=[f"http://{self._es_url}:{self._es_port}"],
                basic_auth=(self._es_username, self._es_password),
            )
        else:
            self._es_client = Elasticsearch(
                hosts=[f"http://{self._es_url}:{self._es_port}"],
            )
        self._es_index_settings = {
            "analysis": {"analyzer": {"default": {"type": "standard"}}},
            "similarity": {
                "custom_bm25": {
                    "type": "BM25",
                    "k1": self._k1,
                    "b": self._b,
                }
            },
        }
        self._es_mappings = {
            "properties": {
                "content": {
                    "type": "text",
                    "similarity": "custom_bm25",
                },
                "metadata": {
                    "type": "keyword",
                },
            }
        }
        if not self._es_client.indices.exists(index=self._index_name):
            self._es_client.indices.create(
                index=self._index_name,
                mappings=self._es_mappings,
                settings=self._es_index_settings,
            )
        self._executor = executor or ThreadPoolExecutor()

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Load document in elasticsearch.

        Args:
            chunks(List[Chunk]): document chunks.

        Return:
            List[str]: chunk ids.
        """
        try:
            from elasticsearch.helpers import bulk
        except ImportError:
            raise ValueError("Please install package `pip install elasticsearch`.")
        es_requests = []
        ids = []
        contents = [chunk.content for chunk in chunks]
        metadatas = [json.dumps(chunk.metadata) for chunk in chunks]
        chunk_ids = [chunk.chunk_id for chunk in chunks]
        for i, content in enumerate(contents):
            es_request = {
                "_op_type": "index",
                "_index": self._index_name,
                "content": content,
                "metadata": metadatas[i],
                "_id": chunk_ids[i],
            }
            ids.append(chunk_ids[i])
            es_requests.append(es_request)
        bulk(self._es_client, es_requests)
        self._es_client.indices.refresh(index=self._index_name)
        return ids

    def similar_search(
        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Search similar text.

        Args:
            text(str): text.
            topk(int): topk.
            filters(MetadataFilters): filters.

        Return:
            List[Chunk]: similar text.
        """
        es_query = {"query": {"match": {"content": text}}}
        res = self._es_client.search(index=self._index_name, body=es_query)
        chunks = []
        for r in res["hits"]["hits"]:
            chunks.append(
                Chunk(
                    chunk_id=r["_id"],
                    content=r["_source"]["content"],
                    metadata=json.loads(r["_source"]["metadata"]),
                )
            )
        return chunks[:topk]

    def similar_search_with_scores(
        self,
        text,
        top_k: int = 10,
        score_threshold: float = 0.3,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Search similar text with scores.

        Args:
            text(str): text.
            top_k(int): top k.
            score_threshold(float): the minimum score used to filter the results.
            filters(MetadataFilters): filters.

        Return:
            List[Chunk]: similar chunks with scores.
        """
        es_query = {"query": {"match": {"content": text}}}
        res = self._es_client.search(index=self._index_name, body=es_query)
        chunks_with_scores = []
        for r in res["hits"]["hits"]:
            if r["_score"] >= score_threshold:
                chunks_with_scores.append(
                    Chunk(
                        chunk_id=r["_id"],
                        content=r["_source"]["content"],
                        metadata=json.loads(r["_source"]["metadata"]),
                        score=r["_score"],
                    )
                )
        if score_threshold is not None and len(chunks_with_scores) == 0:
            logger.warning(
                "No relevant docs were retrieved using the relevance score"
                f" threshold {score_threshold}"
            )
        return chunks_with_scores[:top_k]

    async def aload_document(self, chunks: List[Chunk]) -> List[str]:
        """Async load document in elasticsearch.

        Args:
            chunks(List[Chunk]): document chunks.

        Return:
            List[str]: chunk ids.
        """
        return await blocking_func_to_async(self._executor, self.load_document, chunks)

    def delete_by_ids(self, ids: str) -> List[str]:
        """Delete documents by ids.

        Args:
            ids(str): The document ids to delete, separated by comma.

        Return:
            List[str]: the deleted ids.
        """
        id_list = ids.split(",")
        bulk_body = [
            {"delete": {"_index": self._index_name, "_id": doc_id}}
            for doc_id in id_list
        ]
        self._es_client.bulk(body=bulk_body)
        return id_list

    def delete_vector_name(self, index_name: str):
        """Delete index by name.

        Args:
            index_name(str): The name of index to delete.
        """
        self._es_client.indices.delete(index=self._index_name)
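
Taken together, one possible way to exercise this store for keyword retrieval is sketched below. The connection settings and index name are illustrative assumptions, not values mandated by this commit, and the sketch requires a reachable Elasticsearch instance.

# Hypothetical usage sketch; host, credentials, and index name are made-up examples.
from dbgpt.core import Chunk
from dbgpt.storage.full_text.elasticsearch import (
    ElasticDocumentConfig,
    ElasticDocumentStore,
)

config = ElasticDocumentConfig(
    name="knowledge_keyword_index",  # hex-encoded automatically if it contains Chinese
    uri="localhost",
    port="9200",
    user="elastic",
    password="dbgpt",
    k1=2.0,  # BM25 term-frequency saturation
    b=0.75,  # BM25 document-length normalization
)
store = ElasticDocumentStore(config)

ids = store.load_document(
    [Chunk(content="DB-GPT supports keyword retrieval over Elasticsearch.")]
)
for chunk in store.similar_search_with_scores("keyword retrieval", top_k=5):
    print(chunk.score, chunk.content)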

@@ -0,0 +1,63 @@
"""OpenSearch index store."""
from typing import List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.storage.vector_store.filters import MetadataFilters


class OpenSearch(IndexStoreBase):
    """OpenSearch index store."""

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Load document in index database.

        Args:
            chunks(List[Chunk]): document chunks.

        Return:
            List[str]: chunk ids.
        """
        pass

    async def aload_document(self, chunks: List[Chunk]) -> List[str]:
        """Async load document in index database.

        Args:
            chunks(List[Chunk]): document chunks.

        Return:
            List[str]: chunk ids.
        """
        pass

    def similar_search_with_scores(
        self,
        text,
        topk,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Similar search with scores in index database.

        Args:
            text(str): The query text.
            topk(int): The number of similar documents to return.
            score_threshold(float): Optional, a floating point value between 0 and 1.
            filters(Optional[MetadataFilters]): metadata filters.

        Return:
            List[Chunk]: The similar documents.
        """
        pass

    def delete_by_ids(self, ids: str):
        """Delete docs.

        Args:
            ids(str): The vector ids to delete, separated by comma.
        """
        pass

    def delete_vector_name(self, index_name: str):
        """Delete index by name."""
        pass

@@ -1,7 +1,7 @@
"""Knowledge graph base class."""
import logging
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional
from dbgpt._private.pydantic import ConfigDict
from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
@@ -23,6 +23,6 @@ class KnowledgeGraphBase(IndexStoreBase, ABC):
    def query_graph(self, limit: Optional[int] = None) -> Graph:
        """Get graph data."""

-    def delete_by_ids(self, ids: str):
+    def delete_by_ids(self, ids: str) -> List[str]:
        """Delete document by ids."""
        raise Exception("Delete document not supported by knowledge graph")

@@ -62,6 +62,15 @@ def _import_openspg() -> Tuple[Type, Type]:
    return OpenSPG, OpenSPGConfig


+def _import_full_text() -> Tuple[Type, Type]:
+    from dbgpt.storage.full_text.elasticsearch import (
+        ElasticDocumentConfig,
+        ElasticDocumentStore,
+    )
+
+    return ElasticDocumentStore, ElasticDocumentConfig


def __getattr__(name: str) -> Tuple[Type, Type]:
    if name == "Chroma":
        return _import_chroma()
@@ -79,6 +88,8 @@ def __getattr__(name: str) -> Tuple[Type, Type]:
        return _import_builtin_knowledge_graph()
    elif name == "OpenSPG":
        return _import_openspg()
+    elif name == "FullText":
+        return _import_full_text()
    else:
        raise AttributeError(f"Could not find: {name}")
@@ -94,4 +105,6 @@ __vector_store__ = [
__knowledge_graph__ = ["KnowledgeGraph", "OpenSPG"]
-__all__ = __vector_store__ + __knowledge_graph__
+__document_store__ = ["FullText"]
+
+__all__ = __vector_store__ + __knowledge_graph__ + __document_store__
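
With these registrations in place, the store class and its config are resolved lazily by name through the module-level __getattr__. A small sketch of how a caller might look them up is shown below; it is illustrative only and requires the elasticsearch package to be installed.

# Illustrative lookup through the lazy registry; not code from this commit.
from dbgpt.storage import vector_store

store_cls, config_cls = vector_store.FullText  # resolved via __getattr__
print(store_cls.__name__, config_cls.__name__)  # ElasticDocumentStore ElasticDocumentConfig
# Unknown names raise AttributeError("Could not find: ...").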

@@ -1,296 +0,0 @@
"""Connector for vector store."""
import copy
import logging
import os
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Type, cast
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import (
FunctionDynamicOptions,
OptionValue,
Parameter,
ResourceCategory,
register_resource,
)
from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__)
connector: Dict[str, Tuple[Type, Type]] = {}
pools: DefaultDict[str, Dict] = defaultdict(dict)
def _load_vector_options() -> List[OptionValue]:
from dbgpt.storage import vector_store
return [
OptionValue(label=cls, name=cls, value=cls)
for cls in vector_store.__all__
if issubclass(getattr(vector_store, cls)[0], IndexStoreBase)
]
@register_resource(
_("Vector Store Connector"),
"vector_store_connector",
category=ResourceCategory.VECTOR_STORE,
parameters=[
Parameter.build_from(
_("Vector Store Type"),
"vector_store_type",
str,
description=_("The type of vector store."),
options=FunctionDynamicOptions(func=_load_vector_options),
),
Parameter.build_from(
_("Vector Store Implementation"),
"vector_store_config",
VectorStoreConfig,
description=_("The vector store implementation."),
optional=True,
default=None,
),
],
)
class VectorStoreConnector:
"""The connector for vector store.
VectorStoreConnector, can connect different vector db provided load document api_v1
and similar search api_v1.
1.load_document:knowledge document source into vector store.(Chroma, Milvus,
Weaviate).
2.similar_search: similarity search from vector_store.
3.similar_search_with_scores: similarity search with similarity score from
vector_store
code example:
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
>>> vector_store_config = VectorStoreConfig
>>> vector_store_connector = VectorStoreConnector(vector_store_type="Chroma")
"""
def __init__(
self,
vector_store_type: str,
vector_store_config: Optional[IndexStoreConfig] = None,
) -> None:
"""Create a VectorStoreConnector instance.
Args:
- vector_store_type: vector store type Milvus, Chroma, Weaviate
- ctx: vector store config params.
"""
if vector_store_config is None:
raise Exception("vector_store_config is required")
self._index_store_config = vector_store_config
self._register()
if self._match(vector_store_type):
self.connector_class, self.config_class = connector[vector_store_type]
else:
raise Exception(f"Vector store {vector_store_type} not supported")
logger.info(f"VectorStore:{self.connector_class}")
self._vector_store_type = vector_store_type
self._embeddings = vector_store_config.embedding_fn
config_dict = {}
for key in vector_store_config.to_dict().keys():
value = getattr(vector_store_config, key)
if value is not None:
config_dict[key] = value
for key, value in vector_store_config.model_extra.items():
if value is not None:
config_dict[key] = value
config = self.config_class(**config_dict)
try:
if vector_store_type in pools and config.name in pools[vector_store_type]:
self.client = pools[vector_store_type][config.name]
else:
client = self.connector_class(config)
pools[vector_store_type][config.name] = self.client = client
except Exception as e:
logger.error("connect vector store failed: %s", e)
raise e
@classmethod
def from_default(
cls,
vector_store_type: Optional[str] = None,
embedding_fn: Optional[Any] = None,
vector_store_config: Optional[VectorStoreConfig] = None,
) -> "VectorStoreConnector":
"""Initialize default vector store connector."""
vector_store_type = vector_store_type or os.getenv(
"VECTOR_STORE_TYPE", "Chroma"
)
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
vector_store_config = vector_store_config or ChromaVectorConfig()
vector_store_config.embedding_fn = embedding_fn
real_vector_store_type = cast(str, vector_store_type)
return cls(real_vector_store_type, vector_store_config)
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in vector database.
Args:
- chunks: document chunks.
Return chunk ids.
"""
max_chunks_once_load = (
self._index_store_config.max_chunks_once_load
if self._index_store_config
else 10
)
max_threads = (
self._index_store_config.max_threads if self._index_store_config else 1
)
return self.client.load_document_with_limit(
chunks,
max_chunks_once_load,
max_threads,
)
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Async load document in vector database.
Args:
- chunks: document chunks.
Return chunk ids.
"""
max_chunks_once_load = (
self._index_store_config.max_chunks_once_load
if self._index_store_config
else 10
)
max_threads = (
self._index_store_config.max_threads if self._index_store_config else 1
)
return await self.client.aload_document_with_limit(
chunks, max_chunks_once_load, max_threads
)
def similar_search(
self, doc: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search in vector database.
Args:
- doc: query text
- topk: topk
- filters: metadata filters.
Return:
- chunks: chunks.
"""
return self.client.similar_search(doc, topk, filters)
def similar_search_with_scores(
self,
doc: str,
topk: int,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Similar_search_with_score in vector database.
Return docs and relevance scores in the range [0, 1].
Args:
doc(str): query text
topk(int): return docs nums. Defaults to 4.
score_threshold(float): score_threshold: Optional, a floating point value
between 0 to 1 to filter the resulting set of retrieved docs,0 is
dissimilar, 1 is most similar.
filters: metadata filters.
Return:
- chunks: Return docs and relevance scores in the range [0, 1].
"""
return self.client.similar_search_with_scores(
doc, topk, score_threshold, filters
)
async def asimilar_search_with_scores(
self,
doc: str,
topk: int,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Async similar_search_with_score in vector database."""
return await self.client.asimilar_search_with_scores(
doc, topk, score_threshold, filters
)
@property
def vector_store_config(self) -> IndexStoreConfig:
"""Return the vector store config."""
if not self._index_store_config:
raise ValueError("vector store config not set.")
return self._index_store_config
def vector_name_exists(self):
"""Whether vector name exists."""
return self.client.vector_name_exists()
def delete_vector_name(self, vector_name: str):
"""Delete vector name.
Args:
- vector_name: vector store name
"""
try:
if self.vector_name_exists():
self.client.delete_vector_name(vector_name)
except Exception as e:
logger.error(f"delete vector name {vector_name} failed: {e}")
raise Exception(f"delete name {vector_name} failed")
return True
def delete_by_ids(self, ids):
"""Delete vector by ids.
Args:
- ids: vector ids
"""
return self.client.delete_by_ids(ids=ids)
@property
def current_embeddings(self) -> Optional[Embeddings]:
"""Return the current embeddings."""
return self._embeddings
def new_connector(self, name: str, **kwargs) -> "VectorStoreConnector":
"""Create a new connector.
New connector based on the current connector.
"""
config = copy.copy(self.vector_store_config)
for k, v in kwargs.items():
if v is not None:
setattr(config, k, v)
config.name = name
return self.__class__(self._vector_store_type, config)
def _match(self, vector_store_type) -> bool:
return bool(connector.get(vector_store_type))
def _register(self):
from dbgpt.storage import vector_store
for cls in vector_store.__all__:
store_cls, config_cls = getattr(vector_store, cls)
if issubclass(store_cls, IndexStoreBase) and issubclass(
config_cls, IndexStoreConfig
):
connector[cls] = (store_cls, config_cls)
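
This removed connector keeps one client per (store type, collection name) pair in the module-level pools dict. Below is a standalone sketch of that caching pattern, with made-up names for illustration only.

# Standalone sketch of the (store_type, name) -> client pooling pattern; illustrative only.
from collections import defaultdict
from typing import Dict

pools: Dict[str, Dict[str, "FakeClient"]] = defaultdict(dict)


class FakeClient:
    """Stand-in for a real vector or full-text store client."""

    def __init__(self, name: str):
        self.name = name


def get_client(store_type: str, name: str) -> FakeClient:
    # Reuse an existing client for the same store type and collection name.
    if name in pools[store_type]:
        return pools[store_type][name]
    client = FakeClient(name)
    pools[store_type][name] = client
    return client


a = get_client("Chroma", "default")
b = get_client("Chroma", "default")
assert a is b  # the second call returns the pooled client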

@@ -145,7 +145,7 @@ class ElasticStore(VectorStoreBase):
            elasticsearch_vector_config.get("name") or vector_store_config.name
        )
        # name to hex
-        if string_utils.is_all_chinese(self.collection_name):
+        if string_utils.contains_chinese(self.collection_name):
            bytes_str = self.collection_name.encode("utf-8")
            hex_str = bytes_str.hex()
            self.collection_name = hex_str

@@ -174,7 +174,7 @@ class MilvusStore(VectorStoreBase):
        self.collection_name = (
            milvus_vector_config.get("name") or vector_store_config.name
        )
-        if string_utils.is_all_chinese(self.collection_name):
+        if string_utils.contains_chinese(self.collection_name):
            bytes_str = self.collection_name.encode("utf-8")
            hex_str = bytes_str.hex()
            self.collection_name = hex_str
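
The switch from is_all_chinese to contains_chinese means collection names containing any Chinese characters (not only fully Chinese names) are now hex-encoded. A short standalone illustration of that encoding; the helper below simply mirrors the snippet above and is for demonstration only.

# Demonstration of the hex encoding applied to collection names; standalone sketch.
def to_hex_name(name: str) -> str:
    return name.encode("utf-8").hex()


print(to_hex_name("知识库"))     # fully Chinese: encoded both before and after this change
print(to_hex_name("kb_知识库"))  # mixed name: only encoded now that contains_chinese() is used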