diff --git a/libs/langchain/langchain/retrievers/you.py b/libs/langchain/langchain/retrievers/you.py index 3c281c3de23..2fe0e90a3e8 100644 --- a/libs/langchain/langchain/retrievers/you.py +++ b/libs/langchain/langchain/retrievers/you.py @@ -17,6 +17,7 @@ class YouRetriever(BaseRetriever): """ ydc_api_key: str + endpoint_type: str = "web" @root_validator(pre=True) def validate_client( @@ -34,13 +35,22 @@ class YouRetriever(BaseRetriever): import requests headers = {"X-API-Key": self.ydc_api_key} - results = requests.get( - f"https://api.ydc-index.io/search?query={query}", - headers=headers, - ).json() + if self.endpoint_type == "web": + results = requests.get( + f"https://api.ydc-index.io/search?query={query}", + headers=headers, + ).json() - docs = [] - for hit in results["hits"]: - for snippet in hit["snippets"]: - docs.append(Document(page_content=snippet)) - return docs + docs = [] + for hit in results["hits"]: + for snippet in hit["snippets"]: + docs.append(Document(page_content=snippet)) + return docs + elif self.endpoint_type == "snippet": + results = requests.get( + f"https://api.ydc-index.io/snippet_search?query={query}", + headers=headers, + ).json() + return [Document(page_content=snippet) for snippet in results] + else: + raise RuntimeError(f"Invalid endpoint type provided {self.endpoint_type}")