mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 01:49:58 +00:00
feat(rag): Support rag retriever evaluation (#1291)
This commit is contained in:
@@ -1,16 +1,17 @@
|
||||
"""Embedding retriever operator."""
|
||||
|
||||
from functools import reduce
|
||||
from typing import Any, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
|
||||
"""The Embedding Retriever Operator."""
|
||||
|
||||
def __init__(
|
||||
@@ -32,7 +33,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
rerank=rerank,
|
||||
)
|
||||
|
||||
def retrieve(self, query: Any) -> Any:
|
||||
def retrieve(self, query: Union[str, List[str]]) -> List[Chunk]:
|
||||
"""Retrieve the candidates."""
|
||||
if isinstance(query, str):
|
||||
return self._retriever.retrieve_with_scores(query, self._score_threshold)
|
||||
|
Reference in New Issue
Block a user