Files
langchain/libs/partners/perplexity/langchain_perplexity/retrievers.py

69 lines
2.5 KiB
Python

from __future__ import annotations
from typing import Any, Literal
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field, SecretStr, model_validator
from langchain_perplexity._utils import initialize_client
class PerplexitySearchRetriever(BaseRetriever):
    """Retriever backed by the Perplexity Search API.

    Each query is sent through ``client.search.create`` using the search
    settings configured on this instance. Every returned result becomes a
    ``Document`` whose ``page_content`` is the result snippet and whose
    metadata carries ``title``, ``url``, ``date`` and ``last_updated``.

    Raw constructor input is passed through ``initialize_client`` in a
    ``mode="before"`` validator. That is presumably where ``client`` is set
    up (NOTE(review): confirm in ``langchain_perplexity._utils``).
    """

    # --- Search request settings: forwarded to the API unless None ---

    # Forwarded to the API as ``max_results``.
    k: int = Field(default=10, description="Max results (1-20)")
    max_tokens: int = Field(
        default=25000, description="Max tokens across all results"
    )
    max_tokens_per_page: int = Field(
        default=1024, description="Max tokens per page"
    )
    country: str | None = Field(default=None, description="ISO country code")
    search_domain_filter: list[str] | None = Field(
        default=None, description="Domain filter (max 20)"
    )
    search_recency_filter: Literal["day", "week", "month", "year"] | None = None
    search_after_date: str | None = Field(
        default=None, description="Date filter (format: %m/%d/%Y)"
    )
    search_before_date: str | None = Field(
        default=None, description="Date filter (format: %m/%d/%Y)"
    )

    # --- Connection state ---

    # Object whose ``search.create`` performs the request; kept out of
    # model dumps via ``exclude=True``.
    client: Any = Field(default=None, exclude=True)
    # API key; SecretStr keeps the value masked in reprs. Empty by default.
    pplx_api_key: SecretStr = Field(default=SecretStr(""))

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: dict) -> Any:
        """Prepare raw constructor input by delegating to ``initialize_client``.

        Runs before field validation. Whatever ``initialize_client`` returns
        is what pydantic then validates into the model's fields.
        """
        return initialize_client(values)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Search Perplexity for ``query`` and wrap each hit in a Document.

        Args:
            query: Search query text.
            run_manager: Callback manager for this run (not used here).

        Returns:
            One ``Document`` per entry in ``response.results``, in the order
            the API returned them.
        """
        # Pair each API argument name with its value, in a fixed order.
        # Entries whose value is ``None`` are omitted from the request
        # rather than being sent explicitly.
        candidate_args = (
            ("query", query),
            ("max_results", self.k),
            ("max_tokens", self.max_tokens),
            ("max_tokens_per_page", self.max_tokens_per_page),
            ("country", self.country),
            ("search_domain_filter", self.search_domain_filter),
            ("search_recency_filter", self.search_recency_filter),
            ("search_after_date", self.search_after_date),
            ("search_before_date", self.search_before_date),
        )
        search_kwargs = {
            name: value for name, value in candidate_args if value is not None
        }
        response = self.client.search.create(**search_kwargs)

        documents: list[Document] = []
        for hit in response.results:
            snippet = hit.snippet
            metadata = {
                "title": hit.title,
                "url": hit.url,
                "date": hit.date,
                "last_updated": hit.last_updated,
            }
            documents.append(Document(page_content=snippet, metadata=metadata))
        return documents