mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 13:40:54 +00:00
feat(ChatKnowledge):add hybrid search for knowledge space. (#2722)
This commit is contained in:
@@ -18,7 +18,9 @@ class RetrieverStrategy(str, Enum):
|
||||
"""
|
||||
|
||||
EMBEDDING = "embedding"
|
||||
SEMANTIC = "semantic"
|
||||
GRAPH = "graph"
|
||||
Tree = "tree"
|
||||
KEYWORD = "keyword"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
@@ -278,11 +278,8 @@ class IndexStoreBase(ABC):
|
||||
def is_support_full_text_search(self) -> bool:
|
||||
"""Support full text search.
|
||||
|
||||
Args:
|
||||
collection_name(str): collection name.
|
||||
Return:
|
||||
bool: The similar documents.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Full text search is not supported in this index store."
|
||||
)
|
||||
logger.warning("Full text search is not supported in this index store.")
|
||||
return False
|
||||
|
@@ -226,6 +226,7 @@ class DocTreeRetriever(BaseRetriever):
|
||||
rerank: Optional[Ranker] = None,
|
||||
keywords_extractor: Optional[ExtractorBase] = None,
|
||||
with_content: bool = False,
|
||||
show_tree: bool = True,
|
||||
executor: Optional[Executor] = None,
|
||||
):
|
||||
"""Create DocTreeRetriever.
|
||||
@@ -248,6 +249,7 @@ class DocTreeRetriever(BaseRetriever):
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
self._keywords_extractor = keywords_extractor
|
||||
self._with_content = with_content
|
||||
self._show_tree = show_tree
|
||||
self._tree_indexes = self._initialize_doc_tree(docs)
|
||||
self._executor = executor or ThreadPoolExecutor()
|
||||
|
||||
@@ -305,6 +307,11 @@ class DocTreeRetriever(BaseRetriever):
|
||||
retrieve_node = tree_index.search_keywords(tree_index.root, keyword)
|
||||
if retrieve_node:
|
||||
# If a match is found, return the corresponding chunks
|
||||
if self._show_tree:
|
||||
logger.info(
|
||||
f"DocTreeIndex Match found in: {retrieve_node.level_text}"
|
||||
)
|
||||
tree_index.display_tree(retrieve_node)
|
||||
all_nodes.append(retrieve_node)
|
||||
return all_nodes
|
||||
|
||||
@@ -335,12 +342,11 @@ class DocTreeRetriever(BaseRetriever):
|
||||
for doc in docs:
|
||||
tree_index = DocTreeIndex()
|
||||
for chunk in doc.chunks:
|
||||
if not chunk.metadata.get(TITLE):
|
||||
continue
|
||||
title = chunk.metadata.get("title") or "title"
|
||||
if not self._with_content:
|
||||
tree_index.add_nodes(
|
||||
node_id=chunk.chunk_id,
|
||||
title=chunk.metadata[TITLE],
|
||||
title=title,
|
||||
header1=chunk.metadata.get(HEADER1),
|
||||
header2=chunk.metadata.get(HEADER2),
|
||||
header3=chunk.metadata.get(HEADER3),
|
||||
@@ -351,7 +357,7 @@ class DocTreeRetriever(BaseRetriever):
|
||||
else:
|
||||
tree_index.add_nodes(
|
||||
node_id=chunk.chunk_id,
|
||||
title=chunk.metadata[TITLE],
|
||||
title=title,
|
||||
header1=chunk.metadata.get(HEADER1),
|
||||
header2=chunk.metadata.get(HEADER2),
|
||||
header3=chunk.metadata.get(HEADER3),
|
||||
|
@@ -218,6 +218,31 @@ class ChromaStore(VectorStoreBase):
|
||||
]
|
||||
return self.filter_by_score_threshold(chunks, score_threshold)
|
||||
|
||||
async def afull_text_search(
|
||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search in index database.
|
||||
|
||||
Args:
|
||||
text(str): The query text.
|
||||
topk(int): The number of similar documents to return.
|
||||
filters(Optional[MetadataFilters]): metadata filters.
|
||||
Return:
|
||||
List[Chunk]: The similar documents.
|
||||
"""
|
||||
logger.info("ChromaStore do not support full text search")
|
||||
return []
|
||||
|
||||
def is_support_full_text_search(self) -> bool:
|
||||
"""Support full text search.
|
||||
|
||||
Args:
|
||||
collection_name(str): collection name.
|
||||
Return:
|
||||
bool: is support full texts earch.
|
||||
"""
|
||||
return False
|
||||
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""Whether vector name exists."""
|
||||
try:
|
||||
|
@@ -5,9 +5,12 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from pymilvus.milvus_client import IndexParams, MilvusClient
|
||||
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
@@ -212,13 +215,13 @@ class MilvusStore(VectorStoreBase):
|
||||
)
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
try:
|
||||
from pymilvus import connections
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
# try:
|
||||
# from pymilvus import connections
|
||||
# except ImportError:
|
||||
# raise ValueError(
|
||||
# "Could not import pymilvus python package. "
|
||||
# "Please install it with `pip install pymilvus`."
|
||||
# )
|
||||
connect_kwargs = {}
|
||||
milvus_vector_config = vector_store_config.to_dict()
|
||||
self.uri = milvus_vector_config.get("uri") or os.getenv(
|
||||
@@ -227,7 +230,9 @@ class MilvusStore(VectorStoreBase):
|
||||
self.port = milvus_vector_config.get("post") or os.getenv(
|
||||
"MILVUS_PORT", "19530"
|
||||
)
|
||||
self.username = milvus_vector_config.get("user") or os.getenv("MILVUS_USERNAME")
|
||||
self.username = milvus_vector_config.get("user", "") or os.getenv(
|
||||
"MILVUS_USERNAME"
|
||||
)
|
||||
self.password = milvus_vector_config.get("password") or os.getenv(
|
||||
"MILVUS_PASSWORD"
|
||||
)
|
||||
@@ -245,12 +250,12 @@ class MilvusStore(VectorStoreBase):
|
||||
self.embedding: Embeddings = embedding_fn
|
||||
self.fields: List = []
|
||||
self.alias = milvus_vector_config.get("alias") or "default"
|
||||
self._consistency_level = "Session"
|
||||
|
||||
# use HNSW by default.
|
||||
self.index_params = {
|
||||
"index_type": "HNSW",
|
||||
"metric_type": "COSINE",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
|
||||
# use HNSW by default.
|
||||
@@ -269,6 +274,9 @@ class MilvusStore(VectorStoreBase):
|
||||
self.primary_field = milvus_vector_config.get("primary_field") or "pk_id"
|
||||
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
|
||||
self.text_field = milvus_vector_config.get("text_field") or "content"
|
||||
self.sparse_vector = (
|
||||
milvus_vector_config.get("sparse_vector") or "sparse_vector"
|
||||
)
|
||||
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
|
||||
self.props_field = milvus_vector_config.get("props_field") or "props_field"
|
||||
|
||||
@@ -281,12 +289,9 @@ class MilvusStore(VectorStoreBase):
|
||||
connect_kwargs["user"] = self.username
|
||||
connect_kwargs["password"] = self.password
|
||||
|
||||
connections.connect(
|
||||
host=self.uri or "127.0.0.1",
|
||||
port=self.port or "19530",
|
||||
user=self.username,
|
||||
password=self.password,
|
||||
alias="default",
|
||||
url = f"http://{self.uri}:{self.port}"
|
||||
self._milvus_client = MilvusClient(
|
||||
uri=url, user=self.username, db_name="default"
|
||||
)
|
||||
self.col = self.create_collection(collection_name=self.collection_name)
|
||||
|
||||
@@ -305,6 +310,8 @@ class MilvusStore(VectorStoreBase):
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
Function,
|
||||
FunctionType,
|
||||
connections,
|
||||
utility,
|
||||
)
|
||||
@@ -333,30 +340,57 @@ class MilvusStore(VectorStoreBase):
|
||||
vector_field = self.vector_field
|
||||
text_field = self.text_field
|
||||
metadata_field = self.metadata_field
|
||||
sparse_vector = self.sparse_vector
|
||||
props_field = self.props_field
|
||||
fields = []
|
||||
# max_length = 0
|
||||
# Create the text field
|
||||
fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
text_field,
|
||||
DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
enable_analyzer=self.is_support_full_text_search(),
|
||||
)
|
||||
)
|
||||
# primary key field
|
||||
fields.append(
|
||||
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
||||
)
|
||||
# vector field
|
||||
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||
if self.is_support_full_text_search():
|
||||
fields.append(FieldSchema(sparse_vector, DataType.SPARSE_FLOAT_VECTOR))
|
||||
|
||||
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
|
||||
fields.append(FieldSchema(props_field, DataType.JSON))
|
||||
schema = CollectionSchema(fields)
|
||||
if self.is_support_full_text_search():
|
||||
bm25_fn = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=[self.text_field],
|
||||
output_field_names=[self.sparse_vector],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_fn)
|
||||
# Create the collection
|
||||
collection = Collection(collection_name, schema)
|
||||
self.col = collection
|
||||
index_params = IndexParams()
|
||||
# index parameters for the collection
|
||||
index = self.index_params
|
||||
# milvus index
|
||||
collection.create_index(vector_field, index)
|
||||
index_params.add_index(field_name=self.vector_field, **self.index_params)
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self.is_support_full_text_search():
|
||||
collection.create_index(
|
||||
self.sparse_vector,
|
||||
{
|
||||
"index_type": "AUTOINDEX",
|
||||
"metric_type": "BM25",
|
||||
},
|
||||
)
|
||||
collection.create_index(vector_field, self.index_params)
|
||||
collection.load()
|
||||
return collection
|
||||
return self.col
|
||||
|
||||
def _load_documents(self, documents) -> List[str]:
|
||||
"""Load documents into Milvus.
|
||||
@@ -418,7 +452,7 @@ class MilvusStore(VectorStoreBase):
|
||||
insert_dict.setdefault("metadata", []).append(metadata_json)
|
||||
insert_dict.setdefault("props_field", []).append(metadata_json)
|
||||
# Convert dict to list of lists for insertion
|
||||
insert_list = [insert_dict[x] for x in self.fields]
|
||||
insert_list = [insert_dict[x] for x in self.fields if self.sparse_vector != x]
|
||||
# Insert into the collection.
|
||||
res = self.col.insert(
|
||||
insert_list, partition_name=partition_name, timeout=timeout
|
||||
@@ -570,13 +604,17 @@ class MilvusStore(VectorStoreBase):
|
||||
self.col.load()
|
||||
# use default index params.
|
||||
if param is None:
|
||||
index_type = self.col.indexes[0].params["index_type"]
|
||||
param = self.index_params_map[index_type]
|
||||
for index in self.col.indexes:
|
||||
if index.params["index_type"] == self.index_params.get("index_type"):
|
||||
param = index.params
|
||||
break
|
||||
# query text embedding.
|
||||
query_vector = self.embedding.embed_query(query)
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self.vector_field)
|
||||
if self.sparse_vector in output_fields:
|
||||
output_fields.remove(self.sparse_vector)
|
||||
# milvus search.
|
||||
res = self.col.search(
|
||||
[query_vector],
|
||||
@@ -595,7 +633,10 @@ class MilvusStore(VectorStoreBase):
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
ret.append(
|
||||
(
|
||||
Chunk(content=meta.pop(self.text_field), metadata=meta),
|
||||
Chunk(
|
||||
content=meta.pop(self.text_field),
|
||||
metadata=json.loads(meta.pop(self.metadata_field)),
|
||||
),
|
||||
result.distance,
|
||||
result.id,
|
||||
)
|
||||
@@ -693,3 +734,51 @@ class MilvusStore(VectorStoreBase):
|
||||
utility.drop_collection(self.collection_name)
|
||||
|
||||
logger.info(f"truncate milvus collection {self.collection_name} success")
|
||||
|
||||
def full_text_search(
|
||||
self, text: str, topk: int = 10, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
if self.is_support_full_text_search():
|
||||
milvus_filters = self.convert_metadata_filters(filters) if filters else None
|
||||
results = self._milvus_client.search(
|
||||
collection_name=self.collection_name,
|
||||
data=[text],
|
||||
anns_field=self.sparse_vector,
|
||||
limit=topk,
|
||||
output_fields=["*"],
|
||||
filter=milvus_filters,
|
||||
)
|
||||
chunk_results = [
|
||||
Chunk(
|
||||
content=r.get("entity").get("content"),
|
||||
chunk_id=str(r.get("pk_id")),
|
||||
score=r.get("distance"),
|
||||
metadata=json.loads(r.get("entity").get("metadata")),
|
||||
retriever="full_text",
|
||||
)
|
||||
for r in results[0]
|
||||
]
|
||||
|
||||
return chunk_results
|
||||
|
||||
def is_support_full_text_search(self) -> bool:
|
||||
"""
|
||||
Check Milvus version support full text search.
|
||||
Returns True if the version is >= 2.5.0.
|
||||
"""
|
||||
try:
|
||||
milvus_version_text = self._milvus_client.get_server_version()
|
||||
pattern = r"v(\d+\.\d+\.\d+)"
|
||||
match = re.search(pattern, milvus_version_text)
|
||||
if match:
|
||||
milvus_version = match.group(1)
|
||||
logger.info(f"milvus version is {milvus_version}")
|
||||
# Check if the version is >= 2.5.0
|
||||
return milvus_version >= "2.5.0"
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to check Milvus version:{str(e)}."
|
||||
f"do not support full text index."
|
||||
)
|
||||
return False
|
||||
|
@@ -294,7 +294,7 @@ async def query_page(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.get_document_list({}, page, page_size))
|
||||
return Result.succ(service.get_document_list_page({}, page, page_size))
|
||||
|
||||
|
||||
@router.post("/documents/chunks/add")
|
||||
|
@@ -1,17 +1,25 @@
|
||||
import ast
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, Document, LLMClient
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.base import BaseRetriever, RetrieverStrategy
|
||||
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
from dbgpt_ext.rag.retriever.doc_tree import TreeNode
|
||||
from dbgpt_serve.rag.models.models import KnowledgeSpaceDao
|
||||
from dbgpt_serve.rag.retriever.qa_retriever import QARetriever
|
||||
from dbgpt_serve.rag.retriever.retriever_chain import RetrieverChain
|
||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
"""Knowledge Space retriever."""
|
||||
@@ -24,6 +32,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
rerank: Optional[Ranker] = None,
|
||||
llm_model: Optional[str] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
retrieve_mode: Optional[str] = None,
|
||||
system_app: SystemApp = None,
|
||||
):
|
||||
"""
|
||||
@@ -41,6 +50,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
self._llm_model = llm_model
|
||||
app_config = system_app.config.configs.get("app_config")
|
||||
self._top_k = top_k or app_config.rag.similarity_top_k
|
||||
self._retrieve_mode = retrieve_mode or RetrieverStrategy.HYBRID.value
|
||||
self._embedding_model = embedding_model or app_config.models.default_embedding
|
||||
self._system_app = system_app
|
||||
embedding_factory = self._system_app.get_component(
|
||||
@@ -49,14 +59,14 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
embedding_fn = embedding_factory.create()
|
||||
|
||||
space_dao = KnowledgeSpaceDao()
|
||||
space = space_dao.get_one({"id": space_id})
|
||||
if space is None:
|
||||
space = space_dao.get_one({"name": space_id})
|
||||
if space is None:
|
||||
self._space = space_dao.get_one({"id": space_id})
|
||||
if self._space is None:
|
||||
self._space = space_dao.get_one({"name": space_id})
|
||||
if self._space is None:
|
||||
raise ValueError(f"Knowledge space {space_id} not found")
|
||||
storage_connector = self.storage_manager.get_storage_connector(
|
||||
space.name,
|
||||
space.vector_type,
|
||||
self._storage_connector = self.storage_manager.get_storage_connector(
|
||||
self._space.name,
|
||||
self._space.vector_type,
|
||||
self._llm_model,
|
||||
)
|
||||
self._executor = self._system_app.get_component(
|
||||
@@ -72,7 +82,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
system_app=system_app,
|
||||
),
|
||||
EmbeddingRetriever(
|
||||
index_store=storage_connector,
|
||||
index_store=self._storage_connector,
|
||||
top_k=self._top_k,
|
||||
query_rewrite=self._query_rewrite,
|
||||
rerank=self._rerank,
|
||||
@@ -85,6 +95,19 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
def storage_manager(self):
|
||||
return StorageManager.get_instance(self._system_app)
|
||||
|
||||
@property
|
||||
def rag_service(self):
|
||||
from dbgpt_serve.rag.service.service import Service as RagService
|
||||
|
||||
return RagService.get_instance(self._system_app)
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
worker_manager = self._system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
return DefaultLLMClient(worker_manager, True)
|
||||
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
@@ -133,9 +156,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
Return:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
candidates = await blocking_func_to_async(
|
||||
self._executor, self._retrieve, query, filters
|
||||
)
|
||||
candidates = await self._aretrieve_with_score(query, 0.0, filters)
|
||||
return candidates
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
@@ -146,6 +167,54 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text.
|
||||
score_threshold (float): score threshold.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score.
|
||||
"""
|
||||
if self._retrieve_mode == RetrieverStrategy.SEMANTIC.value:
|
||||
logger.info("Starting Semantic retrieval")
|
||||
return await self.semantic_retrieve(query, score_threshold, filters)
|
||||
elif self._retrieve_mode == RetrieverStrategy.KEYWORD.value:
|
||||
logger.info("Starting Full Text retrieval")
|
||||
return await self.full_text_retrieve(query, self._top_k, filters)
|
||||
elif self._retrieve_mode == RetrieverStrategy.Tree.value:
|
||||
logger.info("Starting Doc Tree retrieval")
|
||||
return await self.tree_index_retrieve(query, self._top_k, filters)
|
||||
elif self._retrieve_mode == RetrieverStrategy.HYBRID.value:
|
||||
logger.info("Starting Hybrid retrieval")
|
||||
tasks = []
|
||||
import asyncio
|
||||
|
||||
tasks.append(self.semantic_retrieve(query, score_threshold, filters))
|
||||
tasks.append(self.full_text_retrieve(query, self._top_k, filters))
|
||||
tasks.append(self.tree_index_retrieve(query, self._top_k, filters))
|
||||
results = await asyncio.gather(*tasks)
|
||||
semantic_candidates = results[0]
|
||||
full_text_candidates = results[1]
|
||||
tree_candidates = results[2]
|
||||
logger.info(
|
||||
f"Hybrid retrieval completed. "
|
||||
f"Found {len(semantic_candidates)} semantic candidates "
|
||||
f"and Found {len(full_text_candidates)} full text candidates."
|
||||
f"and Found {len(tree_candidates)} tree candidates."
|
||||
)
|
||||
candidates = semantic_candidates + full_text_candidates + tree_candidates
|
||||
# Remove duplicates
|
||||
unique_candidates = {chunk.content: chunk for chunk in candidates}
|
||||
return list(unique_candidates.values())
|
||||
|
||||
async def semantic_retrieve(
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text.
|
||||
score_threshold (float): score threshold.
|
||||
@@ -157,3 +226,110 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
return await self._retriever_chain.aretrieve_with_scores(
|
||||
query, score_threshold, filters
|
||||
)
|
||||
|
||||
async def full_text_retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Full Text Retrieve knowledge chunks with score.
|
||||
refer https://www.elastic.co/guide/en/elasticsearch/reference/8.9/
|
||||
index-modules-similarity.html;
|
||||
TF/IDF or BM25 based similarity that has built-in tf normalization and is
|
||||
supposed to work better for short fields (like names).
|
||||
See Okapi_BM25 for more details.
|
||||
|
||||
Args:
|
||||
query (str): query text.
|
||||
top_k (int): top k limit.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score.
|
||||
"""
|
||||
if self._storage_connector.is_support_full_text_search():
|
||||
return await self._storage_connector.afull_text_search(
|
||||
query, top_k, filters
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Full text search is not supported for this storage connector."
|
||||
)
|
||||
return []
|
||||
|
||||
async def tree_index_retrieve(
|
||||
self, query: str, top_k: int, filters: Optional[MetadataFilters] = None
|
||||
):
|
||||
"""Search for keywords in the tree index."""
|
||||
# Check if the keyword is in the node title
|
||||
# If the node has children, recursively search in them
|
||||
# If the node is a leaf, check if it contains the keyword
|
||||
try:
|
||||
docs_res = self.rag_service.get_document_list(
|
||||
{
|
||||
"space": self._space.name,
|
||||
}
|
||||
)
|
||||
docs = []
|
||||
for doc_res in docs_res:
|
||||
doc = Document(
|
||||
content=doc_res.content,
|
||||
)
|
||||
chunks_res = self.rag_service.get_chunk_list(
|
||||
{
|
||||
"document_id": doc_res.id,
|
||||
}
|
||||
)
|
||||
chunks = [
|
||||
Chunk(
|
||||
content=chunk_res.content,
|
||||
metadata=ast.literal_eval(chunk_res.meta_info),
|
||||
)
|
||||
for chunk_res in chunks_res
|
||||
]
|
||||
doc.chunks = chunks
|
||||
docs.append(doc)
|
||||
keyword_extractor = KeywordExtractor(
|
||||
llm_client=self.llm_client, model_name=self._llm_model
|
||||
)
|
||||
from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever
|
||||
|
||||
tree_retriever = DocTreeRetriever(
|
||||
docs=docs,
|
||||
keywords_extractor=keyword_extractor,
|
||||
top_k=self._top_k,
|
||||
query_rewrite=self._query_rewrite,
|
||||
with_content=True,
|
||||
rerank=self._rerank,
|
||||
)
|
||||
candidates = []
|
||||
tree_nodes = await tree_retriever.aretrieve_with_scores(
|
||||
query, top_k, filters
|
||||
)
|
||||
# Convert tree nodes to chunks
|
||||
for node in tree_nodes:
|
||||
chunks = self._traverse(node)
|
||||
candidates.extend(chunks)
|
||||
return candidates
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tree index retrieval: {e}")
|
||||
return []
|
||||
|
||||
def _traverse(self, node: TreeNode):
|
||||
"""Traverse the tree and search for the keyword."""
|
||||
# Check if the node has children
|
||||
result = []
|
||||
if node.children:
|
||||
for child in node.children:
|
||||
result.extend(self._traverse(child))
|
||||
else:
|
||||
# If the node is a leaf, check if it contains the keyword
|
||||
if node:
|
||||
result.append(
|
||||
Chunk(
|
||||
content=node.content,
|
||||
retriever=node.retriever,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
@@ -397,6 +397,21 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
return self.dao.get_list_page(request, page, page_size)
|
||||
|
||||
def get_document_list(
|
||||
self, request: QUERY_SPEC
|
||||
) -> PaginationResult[DocumentServeResponse]:
|
||||
"""Get a list of Flow entities by page
|
||||
|
||||
Args:
|
||||
request (SpaceServeRequest): The request
|
||||
page (int): The page number
|
||||
page_size (int): The page size
|
||||
|
||||
Returns:
|
||||
List[SpaceServeResponse]: The response
|
||||
"""
|
||||
return self._document_dao.get_list(request)
|
||||
|
||||
def get_document_list_page(
|
||||
self, request: QUERY_SPEC, page: int, page_size: int
|
||||
) -> PaginationResult[DocumentServeResponse]:
|
||||
"""Get a list of Flow entities by page
|
||||
|
Reference in New Issue
Block a user