# Mirror of https://github.com/hwchase17/langchain.git
# (synced 2025-05-04 06:37:58 +00:00) - 153 lines, 4.8 KiB, Python.
import os
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
from langchain_core.documents import Document
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
|
|
class SearchDepth(Enum):
    """Enumeration of the supported Tavily search depths.

    Each member's ``value`` is the plain string handed to the Tavily client
    as its ``search_depth`` argument.
    """

    # Default depth used by the retriever.
    BASIC = "basic"
    # Opt-in, more thorough search depth.
    ADVANCED = "advanced"
|
|
|
|
|
|
class TavilySearchAPIRetriever(BaseRetriever):
    """Tavily Search API retriever.

    Setup:
        Install ``langchain-community`` and ``tavily-python``, then set the
        environment variable ``TAVILY_API_KEY`` (or pass ``api_key``).

        .. code-block:: bash

            pip install -U langchain-community tavily-python
            export TAVILY_API_KEY="your-api-key"

    Key init args:
        k: int
            Maximum number of documents to return. When
            ``include_generated_answer`` is True, one of these slots is taken
            by the generated-answer document.
        include_generated_answer: bool
            Include a generated answer with results, as the first document.
        include_raw_content: bool
            Use each result's raw content as ``page_content`` instead of the
            content snippet.
        include_images: bool
            Return images in addition to text. The response's ``images`` are
            attached to every document's metadata under ``"images"``.
        search_depth: SearchDepth
            ``SearchDepth.BASIC`` (default) or ``SearchDepth.ADVANCED``.
        include_domains / exclude_domains: Optional[List[str]]
            Passed through to Tavily's ``include_domains`` /
            ``exclude_domains`` search options.
        kwargs: Optional[Dict[str, Any]]
            Extra keyword arguments forwarded to the Tavily client's
            ``search`` call.
        api_key: Optional[str]
            Tavily API key; falls back to ``TAVILY_API_KEY`` when unset.

    Instantiate:
        .. code-block:: python

            from langchain_community.retrievers import TavilySearchAPIRetriever

            retriever = TavilySearchAPIRetriever(k=3)

    Usage:
        .. code-block:: python

            query = "what year was breath of the wild released?"

            retriever.invoke(query)

    Use within a chain:
        .. code-block:: python

            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnablePassthrough
            from langchain_openai import ChatOpenAI

            prompt = ChatPromptTemplate.from_template(
                \"\"\"Answer the question based only on the context provided.

            Context: {context}

            Question: {question}\"\"\"
            )

            llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

            def format_docs(docs):
                return "\\n\\n".join(doc.page_content for doc in docs)

            chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            chain.invoke("how many units did breath of the wild sell in 2020")

    """  # noqa: E501

    # Maximum number of documents returned (the answer document counts).
    k: int = 10
    # Prepend Tavily's generated answer as an extra document.
    include_generated_answer: bool = False
    # Use ``raw_content`` instead of ``content`` for page_content.
    include_raw_content: bool = False
    # Ask Tavily for images; they are copied into each document's metadata.
    include_images: bool = False
    # Sent to Tavily as ``search_depth`` via the enum's string value.
    search_depth: SearchDepth = SearchDepth.BASIC
    # Forwarded to Tavily as ``include_domains`` / ``exclude_domains``.
    include_domains: Optional[List[str]] = None
    exclude_domains: Optional[List[str]] = None
    # Extra keyword arguments splatted into the client's ``search`` call.
    # May be None, which is treated the same as an empty dict.
    kwargs: Optional[Dict[str, Any]] = {}
    # Explicit API key; when falsy, ``TAVILY_API_KEY`` is read from os.environ.
    api_key: Optional[str] = None

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search Tavily for ``query`` and convert the response to documents.

        Args:
            query: Search query sent to Tavily.
            run_manager: Callback manager for this retriever run (unused).

        Returns:
            One ``Document`` per search result, whose metadata holds the
            result's ``title``, its ``url`` (as ``source``), any other result
            fields, and the response-level ``images``. When
            ``include_generated_answer`` is True, a "Suggested Answer"
            document is prepended.

        Raises:
            ImportError: If the ``tavily-python`` package is not installed.
            KeyError: If ``api_key`` is unset and ``TAVILY_API_KEY`` is not in
                the environment.
        """
        try:
            try:
                from tavily import TavilyClient
            except ImportError:
                # Older versions of tavily exposed the client as ``Client``.
                from tavily import Client as TavilyClient
        except ImportError as e:
            raise ImportError(
                "Tavily python package not found. "
                "Please install it with `pip install tavily-python`."
            ) from e

        tavily = TavilyClient(api_key=self.api_key or os.environ["TAVILY_API_KEY"])
        # Reserve one of the ``k`` slots for the generated-answer document.
        max_results = self.k - 1 if self.include_generated_answer else self.k
        response = tavily.search(
            query=query,
            max_results=max_results,
            search_depth=self.search_depth.value,
            include_answer=self.include_generated_answer,
            include_domains=self.include_domains,
            exclude_domains=self.exclude_domains,
            include_raw_content=self.include_raw_content,
            include_images=self.include_images,
            # ``kwargs`` is Optional; ``**None`` would raise TypeError.
            **(self.kwargs or {}),
        )

        images = response.get("images")
        docs: List[Document] = []
        # A missing/null "results" entry yields no documents instead of a crash.
        for result in response.get("results") or []:
            if self.include_raw_content:
                # Tavily may return null raw_content; Document needs a str.
                page_content = result.get("raw_content") or ""
            else:
                page_content = result.get("content") or ""
            metadata = {
                "title": result.get("title", ""),
                "source": result.get("url", ""),
                # Keep every other result field (e.g. scores) as metadata.
                **{
                    key: value
                    for key, value in result.items()
                    if key not in ("content", "title", "url", "raw_content")
                },
                "images": images,
            }
            docs.append(Document(page_content=page_content, metadata=metadata))

        if self.include_generated_answer:
            answer_doc = Document(
                page_content=response.get("answer") or "",
                metadata={
                    "title": "Suggested Answer",
                    "source": "https://tavily.com/",
                },
            )
            docs.insert(0, answer_doc)

        return docs
|