mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-11-03 08:58:29 +00:00
chore: Add pylint for DB-GPT rag lib (#1267)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""Embedding retriever operator."""
|
||||
|
||||
from functools import reduce
|
||||
from typing import Any, Optional
|
||||
|
||||
from dbgpt.core.awel.task.base import IN
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
@@ -10,25 +11,29 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
"""The Embedding Retriever Operator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float] = 0.3,
|
||||
score_threshold: float = 0.3,
|
||||
query_rewrite: Optional[QueryRewrite] = None,
|
||||
rerank: Ranker = None,
|
||||
vector_store_connector: VectorStoreConnector = None,
|
||||
rerank: Optional[Ranker] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new EmbeddingRetrieverOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._score_threshold = score_threshold
|
||||
self._retriever = EmbeddingRetriever(
|
||||
vector_store_connector=vector_store_connector,
|
||||
top_k=top_k,
|
||||
query_rewrite=query_rewrite,
|
||||
rerank=rerank,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
|
||||
def retrieve(self, query: IN) -> Any:
|
||||
def retrieve(self, query: Any) -> Any:
|
||||
"""Retrieve the candidates."""
|
||||
if isinstance(query, str):
|
||||
return self._retriever.retrieve_with_scores(query, self._score_threshold)
|
||||
elif isinstance(query, list):
|
||||
|
||||
Reference in New Issue
Block a user