feat(ChatKnowledge): add hybrid search for knowledge space. (#2722)

Author: Aries-ckt
Date: 2025-05-23 18:54:22 +08:00
Committed by: GitHub
Parent: a10535d944
Commit: 7e9885574a
8 changed files with 358 additions and 48 deletions


@@ -18,7 +18,9 @@ class RetrieverStrategy(str, Enum):
     """
     EMBEDDING = "embedding"
+    SEMANTIC = "semantic"
     GRAPH = "graph"
+    Tree = "tree"
     KEYWORD = "keyword"
     HYBRID = "hybrid"


@@ -278,11 +278,8 @@ class IndexStoreBase(ABC):
     def is_support_full_text_search(self) -> bool:
         """Support full text search.
-        Args:
-            collection_name(str): collection name.
         Return:
-            bool: The similar documents.
+            bool: Whether full text search is supported.
         """
-        raise NotImplementedError(
-            "Full text search is not supported in this index store."
-        )
+        logger.warning("Full text search is not supported in this index store.")
+        return False
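
With the base class now degrading gracefully instead of raising, a backend opts in by overriding the probe together with the async search hook; a minimal sketch under names assumed from this diff (the concrete class is illustrative):

    from typing import List, Optional

    class MyFullTextStore(IndexStoreBase):
        def is_support_full_text_search(self) -> bool:
            # Opt in: this backend has a keyword/BM25 index.
            return True

        async def afull_text_search(
            self, text: str, topk: int, filters: Optional[MetadataFilters] = None
        ) -> List[Chunk]:
            # Delegate to the backend's full text index here.
            ...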


@@ -226,6 +226,7 @@ class DocTreeRetriever(BaseRetriever):
         rerank: Optional[Ranker] = None,
         keywords_extractor: Optional[ExtractorBase] = None,
         with_content: bool = False,
+        show_tree: bool = True,
         executor: Optional[Executor] = None,
     ):
         """Create DocTreeRetriever.
@@ -248,6 +249,7 @@ class DocTreeRetriever(BaseRetriever):
         self._rerank = rerank or DefaultRanker(self._top_k)
         self._keywords_extractor = keywords_extractor
         self._with_content = with_content
+        self._show_tree = show_tree
         self._tree_indexes = self._initialize_doc_tree(docs)
         self._executor = executor or ThreadPoolExecutor()
@@ -305,6 +307,11 @@ class DocTreeRetriever(BaseRetriever):
             retrieve_node = tree_index.search_keywords(tree_index.root, keyword)
             if retrieve_node:
                 # If a match is found, return the corresponding chunks
+                if self._show_tree:
+                    logger.info(
+                        f"DocTreeIndex Match found in: {retrieve_node.level_text}"
+                    )
+                    tree_index.display_tree(retrieve_node)
                 all_nodes.append(retrieve_node)
         return all_nodes
@@ -335,12 +342,11 @@ class DocTreeRetriever(BaseRetriever):
         for doc in docs:
             tree_index = DocTreeIndex()
             for chunk in doc.chunks:
-                if not chunk.metadata.get(TITLE):
-                    continue
+                title = chunk.metadata.get("title") or "title"
                 if not self._with_content:
                     tree_index.add_nodes(
                         node_id=chunk.chunk_id,
-                        title=chunk.metadata[TITLE],
+                        title=title,
                         header1=chunk.metadata.get(HEADER1),
                         header2=chunk.metadata.get(HEADER2),
                         header3=chunk.metadata.get(HEADER3),
@@ -351,7 +357,7 @@ class DocTreeRetriever(BaseRetriever):
                 else:
                     tree_index.add_nodes(
                         node_id=chunk.chunk_id,
-                        title=chunk.metadata[TITLE],
+                        title=title,
                         header1=chunk.metadata.get(HEADER1),
                         header2=chunk.metadata.get(HEADER2),
                         header3=chunk.metadata.get(HEADER3),
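
Behavior note: chunks without title metadata are no longer dropped; they fall back to the literal string "title", so every chunk is indexed into the tree. A minimal construction sketch showing the new show_tree switch (document contents are illustrative):

    from dbgpt.core import Chunk, Document
    from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever

    doc = Document(content="# Intro\n...")
    doc.chunks = [Chunk(content="...", metadata={"title": "Intro"})]
    retriever = DocTreeRetriever(
        docs=[doc],
        top_k=5,
        with_content=True,
        show_tree=False,  # set False to silence the tree dump added above
    )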


@@ -218,6 +218,31 @@ class ChromaStore(VectorStoreBase):
         ]
         return self.filter_by_score_threshold(chunks, score_threshold)

+    async def afull_text_search(
+        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
+        """Full text search in index database.
+        Args:
+            text(str): The query text.
+            topk(int): The number of similar documents to return.
+            filters(Optional[MetadataFilters]): metadata filters.
+        Return:
+            List[Chunk]: The similar documents.
+        """
+        logger.info("ChromaStore does not support full text search")
+        return []
+
+    def is_support_full_text_search(self) -> bool:
+        """Support full text search.
+        Return:
+            bool: Whether full text search is supported.
+        """
+        return False
+
     def vector_name_exists(self) -> bool:
         """Whether vector name exists."""
         try:
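
ChromaStore declines the capability, so hybrid retrieval silently skips the full text channel for Chroma-backed spaces. Callers are expected to probe before searching, roughly (a sketch; store construction omitted):

    async def safe_full_text(store, query: str, topk: int):
        # Guard on the capability flag before dispatching.
        if store.is_support_full_text_search():
            return await store.afull_text_search(query, topk)
        return []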


@@ -5,9 +5,12 @@ from __future__ import annotations
 import json
 import logging
 import os
+import re
 from dataclasses import dataclass, field
 from typing import Any, Iterable, List, Optional

+from pymilvus.milvus_client import IndexParams, MilvusClient
+
 from dbgpt.core import Chunk, Embeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 from dbgpt.storage.vector_store.base import (
@@ -212,13 +215,13 @@ class MilvusStore(VectorStoreBase):
         )
         self._vector_store_config = vector_store_config

-        try:
-            from pymilvus import connections
-        except ImportError:
-            raise ValueError(
-                "Could not import pymilvus python package. "
-                "Please install it with `pip install pymilvus`."
-            )
+        # try:
+        #     from pymilvus import connections
+        # except ImportError:
+        #     raise ValueError(
+        #         "Could not import pymilvus python package. "
+        #         "Please install it with `pip install pymilvus`."
+        #     )
         connect_kwargs = {}
         milvus_vector_config = vector_store_config.to_dict()
         self.uri = milvus_vector_config.get("uri") or os.getenv(
@@ -227,7 +230,9 @@ class MilvusStore(VectorStoreBase):
         self.port = milvus_vector_config.get("post") or os.getenv(
             "MILVUS_PORT", "19530"
         )
-        self.username = milvus_vector_config.get("user") or os.getenv("MILVUS_USERNAME")
+        self.username = milvus_vector_config.get("user", "") or os.getenv(
+            "MILVUS_USERNAME"
+        )
         self.password = milvus_vector_config.get("password") or os.getenv(
             "MILVUS_PASSWORD"
         )
@@ -245,12 +250,12 @@ class MilvusStore(VectorStoreBase):
         self.embedding: Embeddings = embedding_fn
         self.fields: List = []
         self.alias = milvus_vector_config.get("alias") or "default"
+        self._consistency_level = "Session"
         # use HNSW by default.
         self.index_params = {
             "index_type": "HNSW",
             "metric_type": "COSINE",
+            "params": {"M": 8, "efConstruction": 64},
         }
         # use HNSW by default.
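
For reference, the effective defaults after this hunk, annotated (a sketch; values as in the diff):

    index_params = {
        "index_type": "HNSW",      # graph-based ANN index
        "metric_type": "COSINE",   # cosine similarity for embeddings
        # M: max links per graph node; efConstruction: build-time candidate
        # list size. Larger values trade memory and build time for recall.
        "params": {"M": 8, "efConstruction": 64},
    }
    consistency_level = "Session"  # read-your-own-writes within a session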
@@ -269,6 +274,9 @@ class MilvusStore(VectorStoreBase):
         self.primary_field = milvus_vector_config.get("primary_field") or "pk_id"
         self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
         self.text_field = milvus_vector_config.get("text_field") or "content"
+        self.sparse_vector = (
+            milvus_vector_config.get("sparse_vector") or "sparse_vector"
+        )
         self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
         self.props_field = milvus_vector_config.get("props_field") or "props_field"
@@ -281,12 +289,9 @@ class MilvusStore(VectorStoreBase):
             connect_kwargs["user"] = self.username
             connect_kwargs["password"] = self.password

-        connections.connect(
-            host=self.uri or "127.0.0.1",
-            port=self.port or "19530",
-            user=self.username,
-            password=self.password,
-            alias="default",
-        )
+        url = f"http://{self.uri}:{self.port}"
+        self._milvus_client = MilvusClient(
+            uri=url, user=self.username, db_name="default"
+        )
         self.col = self.create_collection(collection_name=self.collection_name)
@@ -305,6 +310,8 @@ class MilvusStore(VectorStoreBase):
             CollectionSchema,
             DataType,
             FieldSchema,
+            Function,
+            FunctionType,
             connections,
             utility,
         )
@@ -333,30 +340,57 @@ class MilvusStore(VectorStoreBase):
         vector_field = self.vector_field
         text_field = self.text_field
         metadata_field = self.metadata_field
+        sparse_vector = self.sparse_vector
         props_field = self.props_field
         fields = []
         # max_length = 0
         # Create the text field
-        fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
+        fields.append(
+            FieldSchema(
+                text_field,
+                DataType.VARCHAR,
+                max_length=65535,
+                enable_analyzer=self.is_support_full_text_search(),
+            )
+        )
         # primary key field
         fields.append(
             FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
         )
         # vector field
         fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
+        if self.is_support_full_text_search():
+            fields.append(FieldSchema(sparse_vector, DataType.SPARSE_FLOAT_VECTOR))
         fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
         fields.append(FieldSchema(props_field, DataType.JSON))
         schema = CollectionSchema(fields)
+        if self.is_support_full_text_search():
+            bm25_fn = Function(
+                name="text_bm25_emb",
+                input_field_names=[self.text_field],
+                output_field_names=[self.sparse_vector],
+                function_type=FunctionType.BM25,
+            )
+            schema.add_function(bm25_fn)
         # Create the collection
         collection = Collection(collection_name, schema)
         self.col = collection
+        index_params = IndexParams()
         # index parameters for the collection
-        index = self.index_params
-        # milvus index
-        collection.create_index(vector_field, index)
+        index_params.add_index(field_name=self.vector_field, **self.index_params)
+        # Create Sparse Vector Index for the collection
+        if self.is_support_full_text_search():
+            collection.create_index(
+                self.sparse_vector,
+                {
+                    "index_type": "AUTOINDEX",
+                    "metric_type": "BM25",
+                },
+            )
+        collection.create_index(vector_field, self.index_params)
         collection.load()
-        return collection
+        return self.col
     def _load_documents(self, documents) -> List[str]:
         """Load documents into Milvus.
@@ -418,7 +452,7 @@ class MilvusStore(VectorStoreBase):
                 insert_dict.setdefault("metadata", []).append(metadata_json)
                 insert_dict.setdefault("props_field", []).append(metadata_json)
         # Convert dict to list of lists for insertion
-        insert_list = [insert_dict[x] for x in self.fields]
+        insert_list = [insert_dict[x] for x in self.fields if self.sparse_vector != x]
         # Insert into the collection.
         res = self.col.insert(
             insert_list, partition_name=partition_name, timeout=timeout
@@ -570,13 +604,17 @@ class MilvusStore(VectorStoreBase):
         self.col.load()
         # use default index params.
         if param is None:
-            index_type = self.col.indexes[0].params["index_type"]
-            param = self.index_params_map[index_type]
+            for index in self.col.indexes:
+                if index.params["index_type"] == self.index_params.get("index_type"):
+                    param = index.params
+                    break
         # query text embedding.
         query_vector = self.embedding.embed_query(query)
         # Determine result metadata fields.
         output_fields = self.fields[:]
         output_fields.remove(self.vector_field)
+        if self.sparse_vector in output_fields:
+            output_fields.remove(self.sparse_vector)
         # milvus search.
         res = self.col.search(
             [query_vector],
@@ -595,7 +633,10 @@ class MilvusStore(VectorStoreBase):
                 meta = {x: result.entity.get(x) for x in output_fields}
                 ret.append(
                     (
-                        Chunk(content=meta.pop(self.text_field), metadata=meta),
+                        Chunk(
+                            content=meta.pop(self.text_field),
+                            metadata=json.loads(meta.pop(self.metadata_field)),
+                        ),
                         result.distance,
                         result.id,
                     )
@@ -693,3 +734,51 @@ class MilvusStore(VectorStoreBase):
         utility.drop_collection(self.collection_name)
         logger.info(f"truncate milvus collection {self.collection_name} success")

+    def full_text_search(
+        self, text: str, topk: int = 10, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
+        if self.is_support_full_text_search():
+            milvus_filters = self.convert_metadata_filters(filters) if filters else None
+            results = self._milvus_client.search(
+                collection_name=self.collection_name,
+                data=[text],
+                anns_field=self.sparse_vector,
+                limit=topk,
+                output_fields=["*"],
+                filter=milvus_filters,
+            )
+            chunk_results = [
+                Chunk(
+                    content=r.get("entity").get("content"),
+                    chunk_id=str(r.get("pk_id")),
+                    score=r.get("distance"),
+                    metadata=json.loads(r.get("entity").get("metadata")),
+                    retriever="full_text",
+                )
+                for r in results[0]
+            ]
+            return chunk_results
+        return []
+
+    def is_support_full_text_search(self) -> bool:
+        """Check whether the Milvus server supports full text search.
+
+        Returns True if the server version is >= 2.5.0.
+        """
+        try:
+            milvus_version_text = self._milvus_client.get_server_version()
+            pattern = r"v(\d+\.\d+\.\d+)"
+            match = re.search(pattern, milvus_version_text)
+            if match:
+                milvus_version = match.group(1)
+                logger.info(f"milvus version is {milvus_version}")
+                # Check if the version is >= 2.5.0
+                return milvus_version >= "2.5.0"
+            return False
+        except Exception as e:
+            logger.warning(
+                f"Failed to check Milvus version: {str(e)}. "
+                f"Full text index is not supported."
+            )
+            return False
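
One caveat: is_support_full_text_search compares version strings lexicographically, so a future "v2.10.0" would compare below "2.5.0" and wrongly disable BM25. A more robust check, as a sketch (assumes the third-party packaging library is available):

    import re
    from packaging.version import Version

    def supports_full_text(server_version: str) -> bool:
        # Extract "2.5.0" from strings like "v2.5.0" and compare numerically.
        match = re.search(r"v?(\d+\.\d+\.\d+)", server_version)
        return bool(match) and Version(match.group(1)) >= Version("2.5.0")

    assert supports_full_text("v2.10.0")  # a lexicographic compare would fail here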


@@ -294,7 +294,7 @@ async def query_page(
     Returns:
         ServerResponse: The response
     """
-    return Result.succ(service.get_document_list({}, page, page_size))
+    return Result.succ(service.get_document_list_page({}, page, page_size))

 @router.post("/documents/chunks/add")


@@ -1,17 +1,25 @@
+import ast
+import logging
+
 from typing import List, Optional

 from dbgpt.component import ComponentType, SystemApp
-from dbgpt.core import Chunk
+from dbgpt.core import Chunk, Document, LLMClient
+from dbgpt.model import DefaultLLMClient
+from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
-from dbgpt.rag.retriever.base import BaseRetriever
+from dbgpt.rag.retriever.base import BaseRetriever, RetrieverStrategy
+from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
 from dbgpt.storage.vector_store.filters import MetadataFilters
-from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
+from dbgpt.util.executor_utils import ExecutorFactory
+from dbgpt_ext.rag.retriever.doc_tree import TreeNode
 from dbgpt_serve.rag.models.models import KnowledgeSpaceDao
 from dbgpt_serve.rag.retriever.qa_retriever import QARetriever
 from dbgpt_serve.rag.retriever.retriever_chain import RetrieverChain
 from dbgpt_serve.rag.storage_manager import StorageManager

+logger = logging.getLogger(__name__)
+
 class KnowledgeSpaceRetriever(BaseRetriever):
     """Knowledge Space retriever."""
@@ -24,6 +32,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
         rerank: Optional[Ranker] = None,
         llm_model: Optional[str] = None,
         embedding_model: Optional[str] = None,
+        retrieve_mode: Optional[str] = None,
         system_app: SystemApp = None,
     ):
         """
@@ -41,6 +50,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
         self._llm_model = llm_model
         app_config = system_app.config.configs.get("app_config")
         self._top_k = top_k or app_config.rag.similarity_top_k
+        self._retrieve_mode = retrieve_mode or RetrieverStrategy.HYBRID.value
         self._embedding_model = embedding_model or app_config.models.default_embedding
         self._system_app = system_app
         embedding_factory = self._system_app.get_component(
@@ -49,14 +59,14 @@ class KnowledgeSpaceRetriever(BaseRetriever):
         embedding_fn = embedding_factory.create()
         space_dao = KnowledgeSpaceDao()
-        space = space_dao.get_one({"id": space_id})
-        if space is None:
-            space = space_dao.get_one({"name": space_id})
-        if space is None:
+        self._space = space_dao.get_one({"id": space_id})
+        if self._space is None:
+            self._space = space_dao.get_one({"name": space_id})
+        if self._space is None:
             raise ValueError(f"Knowledge space {space_id} not found")
-        storage_connector = self.storage_manager.get_storage_connector(
-            space.name,
-            space.vector_type,
+        self._storage_connector = self.storage_manager.get_storage_connector(
+            self._space.name,
+            self._space.vector_type,
             self._llm_model,
         )
         self._executor = self._system_app.get_component(
@@ -72,7 +82,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
                     system_app=system_app,
                 ),
                 EmbeddingRetriever(
-                    index_store=storage_connector,
+                    index_store=self._storage_connector,
                     top_k=self._top_k,
                     query_rewrite=self._query_rewrite,
                     rerank=self._rerank,
@@ -85,6 +95,19 @@ class KnowledgeSpaceRetriever(BaseRetriever):
     def storage_manager(self):
         return StorageManager.get_instance(self._system_app)

+    @property
+    def rag_service(self):
+        from dbgpt_serve.rag.service.service import Service as RagService
+
+        return RagService.get_instance(self._system_app)
+
+    @property
+    def llm_client(self) -> LLMClient:
+        worker_manager = self._system_app.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        return DefaultLLMClient(worker_manager, True)
+
     def _retrieve(
         self, query: str, filters: Optional[MetadataFilters] = None
     ) -> List[Chunk]:
@@ -133,9 +156,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
         Return:
             List[Chunk]: list of chunks
         """
-        candidates = await blocking_func_to_async(
-            self._executor, self._retrieve, query, filters
-        )
+        candidates = await self._aretrieve_with_score(query, 0.0, filters)
         return candidates

     async def _aretrieve_with_score(
@@ -146,6 +167,54 @@ class KnowledgeSpaceRetriever(BaseRetriever):
     ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.
+        Args:
+            query (str): query text.
+            score_threshold (float): score threshold.
+            filters: (Optional[MetadataFilters]) metadata filters.
+        Return:
+            List[Chunk]: list of chunks with score.
+        """
+        if self._retrieve_mode == RetrieverStrategy.SEMANTIC.value:
+            logger.info("Starting Semantic retrieval")
+            return await self.semantic_retrieve(query, score_threshold, filters)
+        elif self._retrieve_mode == RetrieverStrategy.KEYWORD.value:
+            logger.info("Starting Full Text retrieval")
+            return await self.full_text_retrieve(query, self._top_k, filters)
+        elif self._retrieve_mode == RetrieverStrategy.Tree.value:
+            logger.info("Starting Doc Tree retrieval")
+            return await self.tree_index_retrieve(query, self._top_k, filters)
+        elif self._retrieve_mode == RetrieverStrategy.HYBRID.value:
+            logger.info("Starting Hybrid retrieval")
+            import asyncio
+
+            tasks = []
+            tasks.append(self.semantic_retrieve(query, score_threshold, filters))
+            tasks.append(self.full_text_retrieve(query, self._top_k, filters))
+            tasks.append(self.tree_index_retrieve(query, self._top_k, filters))
+            results = await asyncio.gather(*tasks)
+            semantic_candidates = results[0]
+            full_text_candidates = results[1]
+            tree_candidates = results[2]
+            logger.info(
+                f"Hybrid retrieval completed. "
+                f"Found {len(semantic_candidates)} semantic candidates, "
+                f"{len(full_text_candidates)} full text candidates "
+                f"and {len(tree_candidates)} tree candidates."
+            )
+            candidates = semantic_candidates + full_text_candidates + tree_candidates
+            # Remove duplicates by chunk content
+            unique_candidates = {chunk.content: chunk for chunk in candidates}
+            return list(unique_candidates.values())
+    async def semantic_retrieve(
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
+    ) -> List[Chunk]:
+        """Retrieve knowledge chunks with score.
         Args:
             query (str): query text.
             score_threshold (float): score threshold.
@@ -157,3 +226,110 @@ class KnowledgeSpaceRetriever(BaseRetriever):
         return await self._retriever_chain.aretrieve_with_scores(
             query, score_threshold, filters
         )
+    async def full_text_retrieve(
+        self,
+        query: str,
+        top_k: int,
+        filters: Optional[MetadataFilters] = None,
+    ) -> List[Chunk]:
+        """Full text retrieve knowledge chunks with score.
+
+        Refer to https://www.elastic.co/guide/en/elasticsearch/reference/8.9/
+        index-modules-similarity.html;
+        TF/IDF or BM25 based similarity has built-in tf normalization and is
+        supposed to work better for short fields (like names).
+        See Okapi_BM25 for more details.
+        Args:
+            query (str): query text.
+            top_k (int): top k limit.
+            filters: (Optional[MetadataFilters]) metadata filters.
+        Return:
+            List[Chunk]: list of chunks with score.
+        """
+        if self._storage_connector.is_support_full_text_search():
+            return await self._storage_connector.afull_text_search(
+                query, top_k, filters
+            )
+        else:
+            logger.warning(
+                "Full text search is not supported for this storage connector."
+            )
+            return []
+    async def tree_index_retrieve(
+        self, query: str, top_k: int, filters: Optional[MetadataFilters] = None
+    ):
+        """Search for keywords in the tree index."""
+        # Check if the keyword is in the node title
+        # If the node has children, recursively search in them
+        # If the node is a leaf, check if it contains the keyword
+        try:
+            docs_res = self.rag_service.get_document_list(
+                {
+                    "space": self._space.name,
+                }
+            )
+            docs = []
+            for doc_res in docs_res:
+                doc = Document(
+                    content=doc_res.content,
+                )
+                chunks_res = self.rag_service.get_chunk_list(
+                    {
+                        "document_id": doc_res.id,
+                    }
+                )
+                chunks = [
+                    Chunk(
+                        content=chunk_res.content,
+                        metadata=ast.literal_eval(chunk_res.meta_info),
+                    )
+                    for chunk_res in chunks_res
+                ]
+                doc.chunks = chunks
+                docs.append(doc)
+            keyword_extractor = KeywordExtractor(
+                llm_client=self.llm_client, model_name=self._llm_model
+            )
+            from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever
+
+            tree_retriever = DocTreeRetriever(
+                docs=docs,
+                keywords_extractor=keyword_extractor,
+                top_k=self._top_k,
+                query_rewrite=self._query_rewrite,
+                with_content=True,
+                rerank=self._rerank,
+            )
+            candidates = []
+            tree_nodes = await tree_retriever.aretrieve_with_scores(
+                query, top_k, filters
+            )
+            # Convert tree nodes to chunks
+            for node in tree_nodes:
+                chunks = self._traverse(node)
+                candidates.extend(chunks)
+            return candidates
+        except Exception as e:
+            logger.error(f"Error in tree index retrieval: {e}")
+            return []
+    def _traverse(self, node: TreeNode):
+        """Traverse the tree and collect leaf chunks."""
+        # Check if the node has children
+        result = []
+        if node.children:
+            for child in node.children:
+                result.extend(self._traverse(child))
+        else:
+            # If the node is a leaf, convert it into a chunk
+            if node:
+                result.append(
+                    Chunk(
+                        content=node.content,
+                        retriever=node.retriever,
+                    )
+                )
+        return result
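
End to end, a knowledge space now fans a query out across all three channels when retrieve_mode is left at its hybrid default, and collapses duplicates by chunk content. A minimal usage sketch (space name and system_app wiring are illustrative):

    import asyncio

    from dbgpt.rag.retriever.base import RetrieverStrategy

    retriever = KnowledgeSpaceRetriever(
        space_id="my_space",  # resolved by id first, then by name
        top_k=5,
        retrieve_mode=RetrieverStrategy.HYBRID.value,
        system_app=system_app,  # an initialized SystemApp is assumed
    )
    chunks = asyncio.run(
        retriever.aretrieve_with_scores("what is hybrid search?", 0.3)
    )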


@@ -397,6 +397,21 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeResponse]):
         return self.dao.get_list_page(request, page, page_size)

     def get_document_list(
+        self, request: QUERY_SPEC
+    ) -> PaginationResult[DocumentServeResponse]:
+        """Get a list of document entities.
+        Args:
+            request (QUERY_SPEC): The query request
+        Returns:
+            List[DocumentServeResponse]: The response
+        """
+        return self._document_dao.get_list(request)
+
+    def get_document_list_page(
         self, request: QUERY_SPEC, page: int, page_size: int
     ) -> PaginationResult[DocumentServeResponse]:
         """Get a list of Flow entities by page