mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
69 lines
2.5 KiB
Python
69 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Literal
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
from langchain_core.documents import Document
|
|
from langchain_core.retrievers import BaseRetriever
|
|
from pydantic import Field, SecretStr, model_validator
|
|
|
|
from langchain_perplexity._utils import initialize_client
|
|
|
|
|
|
class PerplexitySearchRetriever(BaseRetriever):
    """Retriever backed by the Perplexity Search endpoint.

    Each lookup calls ``client.search.create`` with the query plus whichever
    optional settings below are not ``None``, and turns every returned result
    into a ``Document`` (snippet as the page content; title, URL and dates as
    metadata).
    """

    # --- Search options -----------------------------------------------------
    # Forwarded to the API under the name "max_results".
    k: int = Field(default=10, description="Max results (1-20)")
    max_tokens: int = Field(default=25000, description="Max tokens across all results")
    max_tokens_per_page: int = Field(default=1024, description="Max tokens per page")
    country: str | None = Field(default=None, description="ISO country code")
    search_domain_filter: list[str] | None = Field(
        default=None, description="Domain filter (max 20)"
    )
    search_recency_filter: Literal["day", "week", "month", "year"] | None = None
    search_after_date: str | None = Field(
        default=None, description="Date filter (format: %m/%d/%Y)"
    )
    search_before_date: str | None = Field(
        default=None, description="Date filter (format: %m/%d/%Y)"
    )

    # --- Client wiring ------------------------------------------------------
    # The client object is excluded from serialization of the model.
    client: Any = Field(default=None, exclude=True)
    pplx_api_key: SecretStr = Field(default=SecretStr(""))

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: dict) -> Any:
        """Pre-process the raw constructor values before field validation.

        Delegates entirely to ``initialize_client`` and returns its result.
        NOTE(review): presumably this builds ``client`` from the API key;
        confirm in ``langchain_perplexity._utils``.
        """
        return initialize_client(values)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Run one Perplexity search and wrap its results as documents.

        Args:
            query: Search text, sent as the ``query`` argument.
            run_manager: Callback manager for this run; not used here.

        Returns:
            One ``Document`` per entry in ``response.results``, in the order
            the response provides them.
        """
        candidate_args = {
            "query": query,
            "max_results": self.k,
            "max_tokens": self.max_tokens,
            "max_tokens_per_page": self.max_tokens_per_page,
            "country": self.country,
            "search_domain_filter": self.search_domain_filter,
            "search_recency_filter": self.search_recency_filter,
            "search_after_date": self.search_after_date,
            "search_before_date": self.search_before_date,
        }

        # Only options that were actually set are passed on; anything left
        # as None is omitted from the call altogether.
        request_args = {}
        for name, value in candidate_args.items():
            if value is not None:
                request_args[name] = value

        search_response = self.client.search.create(**request_args)

        documents: list[Document] = []
        for hit in search_response.results:
            documents.append(
                Document(
                    page_content=hit.snippet,
                    metadata={
                        "title": hit.title,
                        "url": hit.url,
                        "date": hit.date,
                        "last_updated": hit.last_updated,
                    },
                )
            )
        return documents
|