mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 13:40:54 +00:00
feat(ChatKnowledge):add hybrid search for knowledge space. (#2722)
This commit is contained in:
@@ -18,7 +18,9 @@ class RetrieverStrategy(str, Enum):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
EMBEDDING = "embedding"
|
EMBEDDING = "embedding"
|
||||||
|
SEMANTIC = "semantic"
|
||||||
GRAPH = "graph"
|
GRAPH = "graph"
|
||||||
|
Tree = "tree"
|
||||||
KEYWORD = "keyword"
|
KEYWORD = "keyword"
|
||||||
HYBRID = "hybrid"
|
HYBRID = "hybrid"
|
||||||
|
|
||||||
|
@@ -278,11 +278,8 @@ class IndexStoreBase(ABC):
|
|||||||
def is_support_full_text_search(self) -> bool:
|
def is_support_full_text_search(self) -> bool:
|
||||||
"""Support full text search.
|
"""Support full text search.
|
||||||
|
|
||||||
Args:
|
|
||||||
collection_name(str): collection name.
|
|
||||||
Return:
|
Return:
|
||||||
bool: The similar documents.
|
bool: The similar documents.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(
|
logger.warning("Full text search is not supported in this index store.")
|
||||||
"Full text search is not supported in this index store."
|
return False
|
||||||
)
|
|
||||||
|
@@ -226,6 +226,7 @@ class DocTreeRetriever(BaseRetriever):
|
|||||||
rerank: Optional[Ranker] = None,
|
rerank: Optional[Ranker] = None,
|
||||||
keywords_extractor: Optional[ExtractorBase] = None,
|
keywords_extractor: Optional[ExtractorBase] = None,
|
||||||
with_content: bool = False,
|
with_content: bool = False,
|
||||||
|
show_tree: bool = True,
|
||||||
executor: Optional[Executor] = None,
|
executor: Optional[Executor] = None,
|
||||||
):
|
):
|
||||||
"""Create DocTreeRetriever.
|
"""Create DocTreeRetriever.
|
||||||
@@ -248,6 +249,7 @@ class DocTreeRetriever(BaseRetriever):
|
|||||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||||
self._keywords_extractor = keywords_extractor
|
self._keywords_extractor = keywords_extractor
|
||||||
self._with_content = with_content
|
self._with_content = with_content
|
||||||
|
self._show_tree = show_tree
|
||||||
self._tree_indexes = self._initialize_doc_tree(docs)
|
self._tree_indexes = self._initialize_doc_tree(docs)
|
||||||
self._executor = executor or ThreadPoolExecutor()
|
self._executor = executor or ThreadPoolExecutor()
|
||||||
|
|
||||||
@@ -305,6 +307,11 @@ class DocTreeRetriever(BaseRetriever):
|
|||||||
retrieve_node = tree_index.search_keywords(tree_index.root, keyword)
|
retrieve_node = tree_index.search_keywords(tree_index.root, keyword)
|
||||||
if retrieve_node:
|
if retrieve_node:
|
||||||
# If a match is found, return the corresponding chunks
|
# If a match is found, return the corresponding chunks
|
||||||
|
if self._show_tree:
|
||||||
|
logger.info(
|
||||||
|
f"DocTreeIndex Match found in: {retrieve_node.level_text}"
|
||||||
|
)
|
||||||
|
tree_index.display_tree(retrieve_node)
|
||||||
all_nodes.append(retrieve_node)
|
all_nodes.append(retrieve_node)
|
||||||
return all_nodes
|
return all_nodes
|
||||||
|
|
||||||
@@ -335,12 +342,11 @@ class DocTreeRetriever(BaseRetriever):
|
|||||||
for doc in docs:
|
for doc in docs:
|
||||||
tree_index = DocTreeIndex()
|
tree_index = DocTreeIndex()
|
||||||
for chunk in doc.chunks:
|
for chunk in doc.chunks:
|
||||||
if not chunk.metadata.get(TITLE):
|
title = chunk.metadata.get("title") or "title"
|
||||||
continue
|
|
||||||
if not self._with_content:
|
if not self._with_content:
|
||||||
tree_index.add_nodes(
|
tree_index.add_nodes(
|
||||||
node_id=chunk.chunk_id,
|
node_id=chunk.chunk_id,
|
||||||
title=chunk.metadata[TITLE],
|
title=title,
|
||||||
header1=chunk.metadata.get(HEADER1),
|
header1=chunk.metadata.get(HEADER1),
|
||||||
header2=chunk.metadata.get(HEADER2),
|
header2=chunk.metadata.get(HEADER2),
|
||||||
header3=chunk.metadata.get(HEADER3),
|
header3=chunk.metadata.get(HEADER3),
|
||||||
@@ -351,7 +357,7 @@ class DocTreeRetriever(BaseRetriever):
|
|||||||
else:
|
else:
|
||||||
tree_index.add_nodes(
|
tree_index.add_nodes(
|
||||||
node_id=chunk.chunk_id,
|
node_id=chunk.chunk_id,
|
||||||
title=chunk.metadata[TITLE],
|
title=title,
|
||||||
header1=chunk.metadata.get(HEADER1),
|
header1=chunk.metadata.get(HEADER1),
|
||||||
header2=chunk.metadata.get(HEADER2),
|
header2=chunk.metadata.get(HEADER2),
|
||||||
header3=chunk.metadata.get(HEADER3),
|
header3=chunk.metadata.get(HEADER3),
|
||||||
|
@@ -218,6 +218,31 @@ class ChromaStore(VectorStoreBase):
|
|||||||
]
|
]
|
||||||
return self.filter_by_score_threshold(chunks, score_threshold)
|
return self.filter_by_score_threshold(chunks, score_threshold)
|
||||||
|
|
||||||
|
async def afull_text_search(
    self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
    """Placeholder full text search that never returns any chunk.

    No lookup is performed: the call only logs that ChromaStore has no
    full text search and hands back an empty list.

    Args:
        text(str): The query text (not used).
        topk(int): The requested number of results (not used).
        filters(Optional[MetadataFilters]): metadata filters (not used).
    Return:
        List[Chunk]: Always an empty list.
    """
    logger.info("ChromaStore do not support full text search")
    return []
|
||||||
|
|
||||||
|
def is_support_full_text_search(self) -> bool:
    """Report whether this store can run full text search.

    Return:
        bool: Always False for this store.
    """
    return False
|
||||||
|
|
||||||
def vector_name_exists(self) -> bool:
|
def vector_name_exists(self) -> bool:
|
||||||
"""Whether vector name exists."""
|
"""Whether vector name exists."""
|
||||||
try:
|
try:
|
||||||
|
@@ -5,9 +5,12 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
|
from pymilvus.milvus_client import IndexParams, MilvusClient
|
||||||
|
|
||||||
from dbgpt.core import Chunk, Embeddings
|
from dbgpt.core import Chunk, Embeddings
|
||||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||||
from dbgpt.storage.vector_store.base import (
|
from dbgpt.storage.vector_store.base import (
|
||||||
@@ -212,13 +215,13 @@ class MilvusStore(VectorStoreBase):
|
|||||||
)
|
)
|
||||||
self._vector_store_config = vector_store_config
|
self._vector_store_config = vector_store_config
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
from pymilvus import connections
|
# from pymilvus import connections
|
||||||
except ImportError:
|
# except ImportError:
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
"Could not import pymilvus python package. "
|
# "Could not import pymilvus python package. "
|
||||||
"Please install it with `pip install pymilvus`."
|
# "Please install it with `pip install pymilvus`."
|
||||||
)
|
# )
|
||||||
connect_kwargs = {}
|
connect_kwargs = {}
|
||||||
milvus_vector_config = vector_store_config.to_dict()
|
milvus_vector_config = vector_store_config.to_dict()
|
||||||
self.uri = milvus_vector_config.get("uri") or os.getenv(
|
self.uri = milvus_vector_config.get("uri") or os.getenv(
|
||||||
@@ -227,7 +230,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.port = milvus_vector_config.get("post") or os.getenv(
|
self.port = milvus_vector_config.get("post") or os.getenv(
|
||||||
"MILVUS_PORT", "19530"
|
"MILVUS_PORT", "19530"
|
||||||
)
|
)
|
||||||
self.username = milvus_vector_config.get("user") or os.getenv("MILVUS_USERNAME")
|
self.username = milvus_vector_config.get("user", "") or os.getenv(
|
||||||
|
"MILVUS_USERNAME"
|
||||||
|
)
|
||||||
self.password = milvus_vector_config.get("password") or os.getenv(
|
self.password = milvus_vector_config.get("password") or os.getenv(
|
||||||
"MILVUS_PASSWORD"
|
"MILVUS_PASSWORD"
|
||||||
)
|
)
|
||||||
@@ -245,12 +250,12 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.embedding: Embeddings = embedding_fn
|
self.embedding: Embeddings = embedding_fn
|
||||||
self.fields: List = []
|
self.fields: List = []
|
||||||
self.alias = milvus_vector_config.get("alias") or "default"
|
self.alias = milvus_vector_config.get("alias") or "default"
|
||||||
|
self._consistency_level = "Session"
|
||||||
|
|
||||||
# use HNSW by default.
|
# use HNSW by default.
|
||||||
self.index_params = {
|
self.index_params = {
|
||||||
"index_type": "HNSW",
|
"index_type": "HNSW",
|
||||||
"metric_type": "COSINE",
|
"metric_type": "COSINE",
|
||||||
"params": {"M": 8, "efConstruction": 64},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# use HNSW by default.
|
# use HNSW by default.
|
||||||
@@ -269,6 +274,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.primary_field = milvus_vector_config.get("primary_field") or "pk_id"
|
self.primary_field = milvus_vector_config.get("primary_field") or "pk_id"
|
||||||
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
|
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
|
||||||
self.text_field = milvus_vector_config.get("text_field") or "content"
|
self.text_field = milvus_vector_config.get("text_field") or "content"
|
||||||
|
self.sparse_vector = (
|
||||||
|
milvus_vector_config.get("sparse_vector") or "sparse_vector"
|
||||||
|
)
|
||||||
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
|
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
|
||||||
self.props_field = milvus_vector_config.get("props_field") or "props_field"
|
self.props_field = milvus_vector_config.get("props_field") or "props_field"
|
||||||
|
|
||||||
@@ -281,12 +289,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
connect_kwargs["user"] = self.username
|
connect_kwargs["user"] = self.username
|
||||||
connect_kwargs["password"] = self.password
|
connect_kwargs["password"] = self.password
|
||||||
|
|
||||||
connections.connect(
|
url = f"http://{self.uri}:{self.port}"
|
||||||
host=self.uri or "127.0.0.1",
|
self._milvus_client = MilvusClient(
|
||||||
port=self.port or "19530",
|
uri=url, user=self.username, db_name="default"
|
||||||
user=self.username,
|
|
||||||
password=self.password,
|
|
||||||
alias="default",
|
|
||||||
)
|
)
|
||||||
self.col = self.create_collection(collection_name=self.collection_name)
|
self.col = self.create_collection(collection_name=self.collection_name)
|
||||||
|
|
||||||
@@ -305,6 +310,8 @@ class MilvusStore(VectorStoreBase):
|
|||||||
CollectionSchema,
|
CollectionSchema,
|
||||||
DataType,
|
DataType,
|
||||||
FieldSchema,
|
FieldSchema,
|
||||||
|
Function,
|
||||||
|
FunctionType,
|
||||||
connections,
|
connections,
|
||||||
utility,
|
utility,
|
||||||
)
|
)
|
||||||
@@ -333,30 +340,57 @@ class MilvusStore(VectorStoreBase):
|
|||||||
vector_field = self.vector_field
|
vector_field = self.vector_field
|
||||||
text_field = self.text_field
|
text_field = self.text_field
|
||||||
metadata_field = self.metadata_field
|
metadata_field = self.metadata_field
|
||||||
|
sparse_vector = self.sparse_vector
|
||||||
props_field = self.props_field
|
props_field = self.props_field
|
||||||
fields = []
|
fields = []
|
||||||
# max_length = 0
|
# max_length = 0
|
||||||
# Create the text field
|
# Create the text field
|
||||||
fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
|
fields.append(
|
||||||
|
FieldSchema(
|
||||||
|
text_field,
|
||||||
|
DataType.VARCHAR,
|
||||||
|
max_length=65535,
|
||||||
|
enable_analyzer=self.is_support_full_text_search(),
|
||||||
|
)
|
||||||
|
)
|
||||||
# primary key field
|
# primary key field
|
||||||
fields.append(
|
fields.append(
|
||||||
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
||||||
)
|
)
|
||||||
# vector field
|
# vector field
|
||||||
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||||
|
if self.is_support_full_text_search():
|
||||||
|
fields.append(FieldSchema(sparse_vector, DataType.SPARSE_FLOAT_VECTOR))
|
||||||
|
|
||||||
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
|
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
|
||||||
fields.append(FieldSchema(props_field, DataType.JSON))
|
fields.append(FieldSchema(props_field, DataType.JSON))
|
||||||
schema = CollectionSchema(fields)
|
schema = CollectionSchema(fields)
|
||||||
|
if self.is_support_full_text_search():
|
||||||
|
bm25_fn = Function(
|
||||||
|
name="text_bm25_emb",
|
||||||
|
input_field_names=[self.text_field],
|
||||||
|
output_field_names=[self.sparse_vector],
|
||||||
|
function_type=FunctionType.BM25,
|
||||||
|
)
|
||||||
|
schema.add_function(bm25_fn)
|
||||||
# Create the collection
|
# Create the collection
|
||||||
collection = Collection(collection_name, schema)
|
collection = Collection(collection_name, schema)
|
||||||
self.col = collection
|
self.col = collection
|
||||||
|
index_params = IndexParams()
|
||||||
# index parameters for the collection
|
# index parameters for the collection
|
||||||
index = self.index_params
|
index_params.add_index(field_name=self.vector_field, **self.index_params)
|
||||||
# milvus index
|
# Create Sparse Vector Index for the collection
|
||||||
collection.create_index(vector_field, index)
|
if self.is_support_full_text_search():
|
||||||
|
collection.create_index(
|
||||||
|
self.sparse_vector,
|
||||||
|
{
|
||||||
|
"index_type": "AUTOINDEX",
|
||||||
|
"metric_type": "BM25",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
collection.create_index(vector_field, self.index_params)
|
||||||
collection.load()
|
collection.load()
|
||||||
return collection
|
return self.col
|
||||||
|
|
||||||
def _load_documents(self, documents) -> List[str]:
|
def _load_documents(self, documents) -> List[str]:
|
||||||
"""Load documents into Milvus.
|
"""Load documents into Milvus.
|
||||||
@@ -418,7 +452,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
insert_dict.setdefault("metadata", []).append(metadata_json)
|
insert_dict.setdefault("metadata", []).append(metadata_json)
|
||||||
insert_dict.setdefault("props_field", []).append(metadata_json)
|
insert_dict.setdefault("props_field", []).append(metadata_json)
|
||||||
# Convert dict to list of lists for insertion
|
# Convert dict to list of lists for insertion
|
||||||
insert_list = [insert_dict[x] for x in self.fields]
|
insert_list = [insert_dict[x] for x in self.fields if self.sparse_vector != x]
|
||||||
# Insert into the collection.
|
# Insert into the collection.
|
||||||
res = self.col.insert(
|
res = self.col.insert(
|
||||||
insert_list, partition_name=partition_name, timeout=timeout
|
insert_list, partition_name=partition_name, timeout=timeout
|
||||||
@@ -570,13 +604,17 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.col.load()
|
self.col.load()
|
||||||
# use default index params.
|
# use default index params.
|
||||||
if param is None:
|
if param is None:
|
||||||
index_type = self.col.indexes[0].params["index_type"]
|
for index in self.col.indexes:
|
||||||
param = self.index_params_map[index_type]
|
if index.params["index_type"] == self.index_params.get("index_type"):
|
||||||
|
param = index.params
|
||||||
|
break
|
||||||
# query text embedding.
|
# query text embedding.
|
||||||
query_vector = self.embedding.embed_query(query)
|
query_vector = self.embedding.embed_query(query)
|
||||||
# Determine result metadata fields.
|
# Determine result metadata fields.
|
||||||
output_fields = self.fields[:]
|
output_fields = self.fields[:]
|
||||||
output_fields.remove(self.vector_field)
|
output_fields.remove(self.vector_field)
|
||||||
|
if self.sparse_vector in output_fields:
|
||||||
|
output_fields.remove(self.sparse_vector)
|
||||||
# milvus search.
|
# milvus search.
|
||||||
res = self.col.search(
|
res = self.col.search(
|
||||||
[query_vector],
|
[query_vector],
|
||||||
@@ -595,7 +633,10 @@ class MilvusStore(VectorStoreBase):
|
|||||||
meta = {x: result.entity.get(x) for x in output_fields}
|
meta = {x: result.entity.get(x) for x in output_fields}
|
||||||
ret.append(
|
ret.append(
|
||||||
(
|
(
|
||||||
Chunk(content=meta.pop(self.text_field), metadata=meta),
|
Chunk(
|
||||||
|
content=meta.pop(self.text_field),
|
||||||
|
metadata=json.loads(meta.pop(self.metadata_field)),
|
||||||
|
),
|
||||||
result.distance,
|
result.distance,
|
||||||
result.id,
|
result.id,
|
||||||
)
|
)
|
||||||
@@ -693,3 +734,51 @@ class MilvusStore(VectorStoreBase):
|
|||||||
utility.drop_collection(self.collection_name)
|
utility.drop_collection(self.collection_name)
|
||||||
|
|
||||||
logger.info(f"truncate milvus collection {self.collection_name} success")
|
logger.info(f"truncate milvus collection {self.collection_name} success")
|
||||||
|
|
||||||
|
def full_text_search(
    self, text: str, topk: int = 10, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
    """Search the collection's sparse vector field with raw query text.

    Args:
        text(str): The query text, sent unchanged as the search data.
        topk(int): The maximum number of hits to request.
        filters(Optional[MetadataFilters]): metadata filters; converted with
            ``convert_metadata_filters`` before being sent.
    Return:
        List[Chunk]: One chunk per hit, tagged with retriever "full_text".
            Empty when full text search is unsupported or nothing matched.
    """
    if not self.is_support_full_text_search():
        # The result list used to be bound only inside the supported branch,
        # so this path raised UnboundLocalError instead of returning nothing.
        return []

    milvus_filters = self.convert_metadata_filters(filters) if filters else None
    results = self._milvus_client.search(
        collection_name=self.collection_name,
        data=[text],
        anns_field=self.sparse_vector,
        limit=topk,
        output_fields=["*"],
        filter=milvus_filters,
    )
    # A single query is sent, so only the first result group is relevant.
    hits = results[0] if results else []

    chunk_results = []
    for hit in hits:
        entity = hit.get("entity") or {}
        # Use the configured field names, as similar_search does, rather
        # than the hard-coded defaults.
        primary_key = hit.get(self.primary_field)
        if primary_key is None:
            # MilvusClient documentation shows hits keyed by "id"; fall back
            # to it when the primary field name is not present on the hit.
            primary_key = hit.get("id")
        raw_metadata = entity.get(self.metadata_field)
        chunk_results.append(
            Chunk(
                content=entity.get(self.text_field),
                chunk_id=str(primary_key),
                score=hit.get("distance"),
                # Metadata is stored as a JSON string; tolerate its absence.
                metadata=json.loads(raw_metadata) if raw_metadata else {},
                retriever="full_text",
            )
        )
    return chunk_results
|
||||||
|
|
||||||
|
def is_support_full_text_search(self) -> bool:
    """Check whether the connected Milvus server supports full text search.

    The server version string is read from the Milvus client and its
    ``v<major>.<minor>.<patch>`` part is compared numerically with 2.5.0.

    Returns:
        bool: True if the version is >= 2.5.0. False when the version is
            lower, cannot be parsed, or the lookup raises.
    """
    try:
        milvus_version_text = self._milvus_client.get_server_version()
        pattern = r"v((\d+)\.(\d+)\.(\d+))"
        match = re.search(pattern, milvus_version_text)
        if not match:
            return False
        milvus_version = match.group(1)
        logger.info(f"milvus version is {milvus_version}")
        # Compare as integers: as strings, "2.10.0" sorts before "2.5.0".
        version_parts = tuple(int(part) for part in match.groups()[1:])
        return version_parts >= (2, 5, 0)
    except Exception as e:
        # Best effort: any failure means full text search is treated as
        # unavailable rather than breaking the caller.
        logger.warning(
            f"Failed to check Milvus version:{str(e)}."
            f"do not support full text index."
        )
        return False
|
||||||
|
@@ -294,7 +294,7 @@ async def query_page(
|
|||||||
Returns:
|
Returns:
|
||||||
ServerResponse: The response
|
ServerResponse: The response
|
||||||
"""
|
"""
|
||||||
return Result.succ(service.get_document_list({}, page, page_size))
|
return Result.succ(service.get_document_list_page({}, page, page_size))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/documents/chunks/add")
|
@router.post("/documents/chunks/add")
|
||||||
|
@@ -1,17 +1,25 @@
|
|||||||
|
import ast
|
||||||
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from dbgpt.component import ComponentType, SystemApp
|
from dbgpt.component import ComponentType, SystemApp
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk, Document, LLMClient
|
||||||
|
from dbgpt.model import DefaultLLMClient
|
||||||
|
from dbgpt.model.cluster import WorkerManagerFactory
|
||||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||||
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
|
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
|
||||||
from dbgpt.rag.retriever.base import BaseRetriever
|
from dbgpt.rag.retriever.base import BaseRetriever, RetrieverStrategy
|
||||||
|
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
from dbgpt.util.executor_utils import ExecutorFactory
|
||||||
|
from dbgpt_ext.rag.retriever.doc_tree import TreeNode
|
||||||
from dbgpt_serve.rag.models.models import KnowledgeSpaceDao
|
from dbgpt_serve.rag.models.models import KnowledgeSpaceDao
|
||||||
from dbgpt_serve.rag.retriever.qa_retriever import QARetriever
|
from dbgpt_serve.rag.retriever.qa_retriever import QARetriever
|
||||||
from dbgpt_serve.rag.retriever.retriever_chain import RetrieverChain
|
from dbgpt_serve.rag.retriever.retriever_chain import RetrieverChain
|
||||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeSpaceRetriever(BaseRetriever):
|
class KnowledgeSpaceRetriever(BaseRetriever):
|
||||||
"""Knowledge Space retriever."""
|
"""Knowledge Space retriever."""
|
||||||
@@ -24,6 +32,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
rerank: Optional[Ranker] = None,
|
rerank: Optional[Ranker] = None,
|
||||||
llm_model: Optional[str] = None,
|
llm_model: Optional[str] = None,
|
||||||
embedding_model: Optional[str] = None,
|
embedding_model: Optional[str] = None,
|
||||||
|
retrieve_mode: Optional[str] = None,
|
||||||
system_app: SystemApp = None,
|
system_app: SystemApp = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -41,6 +50,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
self._llm_model = llm_model
|
self._llm_model = llm_model
|
||||||
app_config = system_app.config.configs.get("app_config")
|
app_config = system_app.config.configs.get("app_config")
|
||||||
self._top_k = top_k or app_config.rag.similarity_top_k
|
self._top_k = top_k or app_config.rag.similarity_top_k
|
||||||
|
self._retrieve_mode = retrieve_mode or RetrieverStrategy.HYBRID.value
|
||||||
self._embedding_model = embedding_model or app_config.models.default_embedding
|
self._embedding_model = embedding_model or app_config.models.default_embedding
|
||||||
self._system_app = system_app
|
self._system_app = system_app
|
||||||
embedding_factory = self._system_app.get_component(
|
embedding_factory = self._system_app.get_component(
|
||||||
@@ -49,14 +59,14 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
embedding_fn = embedding_factory.create()
|
embedding_fn = embedding_factory.create()
|
||||||
|
|
||||||
space_dao = KnowledgeSpaceDao()
|
space_dao = KnowledgeSpaceDao()
|
||||||
space = space_dao.get_one({"id": space_id})
|
self._space = space_dao.get_one({"id": space_id})
|
||||||
if space is None:
|
if self._space is None:
|
||||||
space = space_dao.get_one({"name": space_id})
|
self._space = space_dao.get_one({"name": space_id})
|
||||||
if space is None:
|
if self._space is None:
|
||||||
raise ValueError(f"Knowledge space {space_id} not found")
|
raise ValueError(f"Knowledge space {space_id} not found")
|
||||||
storage_connector = self.storage_manager.get_storage_connector(
|
self._storage_connector = self.storage_manager.get_storage_connector(
|
||||||
space.name,
|
self._space.name,
|
||||||
space.vector_type,
|
self._space.vector_type,
|
||||||
self._llm_model,
|
self._llm_model,
|
||||||
)
|
)
|
||||||
self._executor = self._system_app.get_component(
|
self._executor = self._system_app.get_component(
|
||||||
@@ -72,7 +82,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
system_app=system_app,
|
system_app=system_app,
|
||||||
),
|
),
|
||||||
EmbeddingRetriever(
|
EmbeddingRetriever(
|
||||||
index_store=storage_connector,
|
index_store=self._storage_connector,
|
||||||
top_k=self._top_k,
|
top_k=self._top_k,
|
||||||
query_rewrite=self._query_rewrite,
|
query_rewrite=self._query_rewrite,
|
||||||
rerank=self._rerank,
|
rerank=self._rerank,
|
||||||
@@ -85,6 +95,19 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
def storage_manager(self):
|
def storage_manager(self):
|
||||||
return StorageManager.get_instance(self._system_app)
|
return StorageManager.get_instance(self._system_app)
|
||||||
|
|
||||||
|
@property
def rag_service(self):
    """Return the RAG ``Service`` instance bound to this system app."""
    # Kept as a call-time import, as in the original; presumably this avoids
    # a circular import with the service module -- TODO confirm.
    from dbgpt_serve.rag.service.service import Service

    service_cls = Service
    return service_cls.get_instance(self._system_app)
|
||||||
|
|
||||||
|
@property
def llm_client(self) -> LLMClient:
    """Build a ``DefaultLLMClient`` on top of the system app's worker manager.

    A new client object is constructed on every access.
    """
    manager_factory = self._system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    )
    manager = manager_factory.create()
    return DefaultLLMClient(manager, True)
|
||||||
|
|
||||||
def _retrieve(
|
def _retrieve(
|
||||||
self, query: str, filters: Optional[MetadataFilters] = None
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
@@ -133,9 +156,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
candidates = await blocking_func_to_async(
|
candidates = await self._aretrieve_with_score(query, 0.0, filters)
|
||||||
self._executor, self._retrieve, query, filters
|
|
||||||
)
|
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
async def _aretrieve_with_score(
|
async def _aretrieve_with_score(
|
||||||
@@ -146,6 +167,54 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): query text.
|
||||||
|
score_threshold (float): score threshold.
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks with score.
|
||||||
|
"""
|
||||||
|
if self._retrieve_mode == RetrieverStrategy.SEMANTIC.value:
|
||||||
|
logger.info("Starting Semantic retrieval")
|
||||||
|
return await self.semantic_retrieve(query, score_threshold, filters)
|
||||||
|
elif self._retrieve_mode == RetrieverStrategy.KEYWORD.value:
|
||||||
|
logger.info("Starting Full Text retrieval")
|
||||||
|
return await self.full_text_retrieve(query, self._top_k, filters)
|
||||||
|
elif self._retrieve_mode == RetrieverStrategy.Tree.value:
|
||||||
|
logger.info("Starting Doc Tree retrieval")
|
||||||
|
return await self.tree_index_retrieve(query, self._top_k, filters)
|
||||||
|
elif self._retrieve_mode == RetrieverStrategy.HYBRID.value:
|
||||||
|
logger.info("Starting Hybrid retrieval")
|
||||||
|
tasks = []
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
tasks.append(self.semantic_retrieve(query, score_threshold, filters))
|
||||||
|
tasks.append(self.full_text_retrieve(query, self._top_k, filters))
|
||||||
|
tasks.append(self.tree_index_retrieve(query, self._top_k, filters))
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
semantic_candidates = results[0]
|
||||||
|
full_text_candidates = results[1]
|
||||||
|
tree_candidates = results[2]
|
||||||
|
logger.info(
|
||||||
|
f"Hybrid retrieval completed. "
|
||||||
|
f"Found {len(semantic_candidates)} semantic candidates "
|
||||||
|
f"and Found {len(full_text_candidates)} full text candidates."
|
||||||
|
f"and Found {len(tree_candidates)} tree candidates."
|
||||||
|
)
|
||||||
|
candidates = semantic_candidates + full_text_candidates + tree_candidates
|
||||||
|
# Remove duplicates
|
||||||
|
unique_candidates = {chunk.content: chunk for chunk in candidates}
|
||||||
|
return list(unique_candidates.values())
|
||||||
|
|
||||||
|
async def semantic_retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text.
|
query (str): query text.
|
||||||
score_threshold (float): score threshold.
|
score_threshold (float): score threshold.
|
||||||
@@ -157,3 +226,110 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
return await self._retriever_chain.aretrieve_with_scores(
|
return await self._retriever_chain.aretrieve_with_scores(
|
||||||
query, score_threshold, filters
|
query, score_threshold, filters
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def full_text_retrieve(
    self,
    query: str,
    top_k: int,
    filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
    """Retrieve chunks through the storage connector's full text search.

    The connector is first asked whether it supports full text search.
    If it does, the query is delegated to its ``afull_text_search``;
    otherwise a warning is logged and nothing is returned.

    Background on BM25 / TF-IDF similarity:
    https://www.elastic.co/guide/en/elasticsearch/reference/8.9/
    index-modules-similarity.html

    Args:
        query (str): query text.
        top_k (int): top k limit.
        filters: (Optional[MetadataFilters]) metadata filters.

    Return:
        List[Chunk]: chunks from the connector, or an empty list when the
            connector has no full text search.
    """
    connector = self._storage_connector
    if not connector.is_support_full_text_search():
        logger.warning(
            "Full text search is not supported for this storage connector."
        )
        return []
    return await connector.afull_text_search(query, top_k, filters)
|
||||||
|
|
||||||
|
async def tree_index_retrieve(
    self, query: str, top_k: int, filters: Optional[MetadataFilters] = None
):
    """Retrieve chunks by keyword search over a document tree index.

    Every document of this knowledge space is loaded together with its
    chunks, a ``DocTreeRetriever`` is built over them, and the leaf nodes
    under each matched tree node are converted to chunks.

    Args:
        query (str): query text.
        top_k (int): forwarded to the tree retriever (see the note below).
        filters: (Optional[MetadataFilters]) metadata filters.

    Return:
        The chunks collected from the matched tree nodes. Any exception is
        logged and results in an empty list.
    """
    try:
        space_docs = self.rag_service.get_document_list(
            {"space": self._space.name}
        )
        docs = []
        for doc_row in space_docs:
            document = Document(content=doc_row.content)
            chunk_rows = self.rag_service.get_chunk_list(
                {"document_id": doc_row.id}
            )
            # meta_info is parsed as a Python literal, not as JSON.
            document.chunks = [
                Chunk(
                    content=row.content,
                    metadata=ast.literal_eval(row.meta_info),
                )
                for row in chunk_rows
            ]
            docs.append(document)

        extractor = KeywordExtractor(
            llm_client=self.llm_client, model_name=self._llm_model
        )
        from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever

        retriever = DocTreeRetriever(
            docs=docs,
            keywords_extractor=extractor,
            top_k=self._top_k,
            query_rewrite=self._query_rewrite,
            with_content=True,
            rerank=self._rerank,
        )
        # NOTE(review): top_k is passed in the position that the retriever
        # chain call in this file uses for score_threshold -- confirm that
        # DocTreeRetriever expects it there.
        matched_nodes = await retriever.aretrieve_with_scores(
            query, top_k, filters
        )
        # Flatten every matched subtree into its leaf chunks.
        return [
            chunk for node in matched_nodes for chunk in self._traverse(node)
        ]
    except Exception as e:
        logger.error(f"Error in tree index retrieval: {e}")
        return []
|
||||||
|
|
||||||
|
def _traverse(self, node: TreeNode):
    """Collect the leaf nodes below ``node`` as chunks.

    A node with children contributes only what its children yield; its own
    content is not emitted. A leaf becomes a single ``Chunk`` carrying the
    node's content and retriever.
    """
    if node.children:
        collected = []
        for child in node.children:
            collected += self._traverse(child)
        return collected
    # Leaf case: keep the original truthiness check on the node itself.
    if not node:
        return []
    return [
        Chunk(
            content=node.content,
            retriever=node.retriever,
        )
    ]
|
||||||
|
@@ -397,6 +397,21 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
|||||||
return self.dao.get_list_page(request, page, page_size)
|
return self.dao.get_list_page(request, page, page_size)
|
||||||
|
|
||||||
def get_document_list(
|
def get_document_list(
    self, request: QUERY_SPEC
) -> PaginationResult[DocumentServeResponse]:
    """Get the documents matching the request, without paging.

    Args:
        request (QUERY_SPEC): The query handed to the document DAO.

    Returns:
        Whatever the document DAO's ``get_list`` returns for the request.
        NOTE(review): the annotation says PaginationResult, but no page
        arguments are involved here -- confirm the actual return type.
    """
    documents = self._document_dao.get_list(request)
    return documents
|
||||||
|
|
||||||
|
def get_document_list_page(
|
||||||
self, request: QUERY_SPEC, page: int, page_size: int
|
self, request: QUERY_SPEC, page: int, page_size: int
|
||||||
) -> PaginationResult[DocumentServeResponse]:
|
) -> PaginationResult[DocumentServeResponse]:
|
||||||
"""Get a list of Flow entities by page
|
"""Get a list of Flow entities by page
|
||||||
|
Reference in New Issue
Block a user