mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
33 lines
976 B
Python
33 lines
976 B
Python
from typing import Any, Type
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
from langchain_core.documents import Document
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
from langchain_tests.integration_tests import RetrieversIntegrationTests
|
|
|
|
|
|
class ParrotRetriever(BaseRetriever):
|
|
parrot_name: str
|
|
k: int = 3
|
|
|
|
def _get_relevant_documents(
|
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
|
) -> list[Document]:
|
|
k = kwargs.get("k", self.k)
|
|
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
|
|
|
|
|
|
class TestParrotRetrieverIntegration(RetrieversIntegrationTests):
|
|
@property
|
|
def retriever_constructor(self) -> Type[ParrotRetriever]:
|
|
return ParrotRetriever
|
|
|
|
@property
|
|
def retriever_constructor_params(self) -> dict:
|
|
return {"parrot_name": "Polly"}
|
|
|
|
@property
|
|
def retriever_query_example(self) -> str:
|
|
return "parrot"
|