mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 11:41:51 +00:00
656 lines
102 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"id": "9e7a7c86",
"metadata": {},
"source": [
"### Custom RAG Agent Workflow with Open-Source LLMs Running Locally on an Intel CPU"
]
},
{
"cell_type": "markdown",
"id": "f309f56d-1db4-4e03-870e-a2a6f5ee4dc5",
"metadata": {},
"source": [
"Author - Pratool Bharti (pratool.bharti@intel.com)"
]
},
{
"cell_type": "markdown",
"id": "0af01c3c-c42a-4ba5-95fa-4b83fd77fe9d",
"metadata": {},
"source": [
"This notebook demonstrates a Retrieval-Augmented Generation (RAG) agent that routes each question along one of two paths to find an answer. The agent generates answers from documents retrieved either from a local vector database or from a web search: after retrieving documents from the vector database, it grades each one for relevance, and if any of them is graded irrelevant, it supplements the relevant ones with web search results before generating the answer. The entire pipeline runs locally on an Intel Xeon CPU, using open-source models for both the LLM (Llama 3.1 served by Ollama) and the embeddings (Nomic `nomic-embed-text-v1.5`)."
]
},
{
"cell_type": "markdown",
"id": "8b50e68f",
"metadata": {},
"source": [
"<figure style=\"text-align: center;\">\n",
"<figcaption style=\"text-align: center;\">Flow chart for the Custom RAG Agent Workflow</figcaption>\n",
"<img src=\"\" />\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "24f76969",
"metadata": {},
"source": [
"Install the required libraries in a conda or venv environment:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "746ae008",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"# beautifulsoup4 is required by WebBaseLoader\n",
"!pip install --upgrade --quiet tiktoken scikit-learn gpt4all beautifulsoup4 langchain langchain-community langchain-core langchain_nomic langchain_ollama langgraph"
]
},
{
"cell_type": "markdown",
"id": "399f7e2e",
"metadata": {},
"source": [
"On Linux, use the following commands to install Ollama and to download and run the Llama 3.1 model locally:\n",
"```\n",
"curl -fsSL https://ollama.com/install.sh | sh\n",
"ollama run llama3.1\n",
"```\n",
"\n",
"`ollama pull llama3.1` downloads the model without starting an interactive chat session."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7ea62fe-7ea0-4e98-95e5-df79599b1545",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"This cell sets the environment variables used by the local RAG (Retrieval-Augmented Generation) agent.\n",
"Replace the placeholder values with your own API keys before running it.\n",
"\n",
"Environment Variables:\n",
"- USER_AGENT: User agent string sent with the HTTP requests made by WebBaseLoader.\n",
"- LANGSMITH_TRACING: Enables (\"true\") or disables (\"false\") LangSmith tracing.\n",
"- LANGSMITH_API_KEY: API key for LangSmith (needed only when tracing is enabled).\n",
"- TAVILY_API_KEY: API key for the Tavily web search service.\n",
"\"\"\"\n",
"import os\n",
"\n",
"os.environ[\"USER_AGENT\"] = \"myagent\"\n",
"os.environ[\"LANGSMITH_TRACING\"] = \"true\"\n",
"os.environ[\"LANGSMITH_API_KEY\"] = \"xxxx\"\n",
"os.environ[\"TAVILY_API_KEY\"] = \"tvly-xxxx\""
]
},
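{
"cell_type": "markdown",
"id": "env-check-sketch-md",
"metadata": {},
"source": [
"The cell above hardcodes placeholder values for the API keys. As an optional sanity check (a sketch, not part of the original workflow), the hypothetical helper `missing_env_vars` lists the variables that are still unset or still hold the placeholder values used above, so a forgotten key shows up before the Tavily or LangSmith calls fail."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "env-check-sketch",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Hypothetical helper (not part of LangChain); the placeholder values match the ones set above\n",
"PLACEHOLDER_VALUES = {\"\", \"xxxx\", \"tvly-xxxx\"}\n",
"\n",
"\n",
"def missing_env_vars(names, env=None):\n",
"    \"\"\"Return the names that are unset or still set to a placeholder value.\"\"\"\n",
"    env = os.environ if env is None else env\n",
"    return [name for name in names if env.get(name, \"\").strip() in PLACEHOLDER_VALUES]\n",
"\n",
"\n",
"# LANGSMITH_API_KEY is only needed when LangSmith tracing is enabled\n",
"required = [\"USER_AGENT\", \"TAVILY_API_KEY\"]\n",
"if os.environ.get(\"LANGSMITH_TRACING\", \"\").lower() == \"true\":\n",
"    required.append(\"LANGSMITH_API_KEY\")\n",
"print(\"Missing or placeholder variables:\", missing_env_vars(required))"
]
},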
{
"cell_type": "markdown",
"id": "f4fe714b",
"metadata": {},
"source": [
"Use a local embedding model to embed the documents and store them in a vector database"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8d1b3be3-b150-4e39-aecf-f4a51a5eb358",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to load libllamamodel-mainline-cuda-avxonly.so: dlopen: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"Failed to load libllamamodel-mainline-cuda.so: dlopen: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
]
}
],
"source": [
"\"\"\"\n",
"This cell performs the following tasks:\n",
"\n",
"1. Imports the required modules and classes from LangChain and related libraries.\n",
"2. Defines a list of IRS URLs from which to load tax-related documents.\n",
"3. Loads the documents from these URLs using WebBaseLoader.\n",
"4. Flattens the per-URL lists of loaded documents into a single list.\n",
"5. Initializes a RecursiveCharacterTextSplitter that measures chunk size in tiktoken tokens (chunk_size=250, chunk_overlap=0).\n",
"6. Splits the loaded documents into chunks using the text splitter.\n",
"7. Builds an SKLearnVectorStore from the chunks, embedded with the local \"nomic-embed-text-v1.5\" model via NomicEmbeddings.\n",
"8. Converts the vector store into a retriever that returns the k=4 nearest chunks.\n",
"\n",
"Modules and Classes:\n",
"- RecursiveCharacterTextSplitter: Recursively splits text into chunks no larger than chunk_size.\n",
"- WebBaseLoader: Loads documents from web URLs.\n",
"- SKLearnVectorStore: Stores document vectors for retrieval.\n",
"- NomicEmbeddings: Generates embeddings for documents.\n",
"\n",
"Variables:\n",
"- urls: List of URLs to load documents from.\n",
"- docs: List of loaded documents from the URLs.\n",
"- docs_list: Flattened list of loaded documents.\n",
"- text_splitter: Instance of RecursiveCharacterTextSplitter.\n",
"- doc_splits: List of document chunks.\n",
"- vectorstore: Instance of SKLearnVectorStore.\n",
"- retriever: Retriever instance for querying the vector store.\n",
"\"\"\"\n",
"\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain_community.document_loaders import WebBaseLoader\n",
"from langchain_community.vectorstores import SKLearnVectorStore\n",
"from langchain_nomic.embeddings import NomicEmbeddings\n",
"\n",
"# List of URLs to load documents from\n",
"urls = [\n",
"    \"https://www.irs.gov/newsroom/irs-releases-tax-inflation-adjustments-for-tax-year-2025\",\n",
"    \"https://www.irs.gov/newsroom/401k-limit-increases-to-23500-for-2025-ira-limit-remains-7000\",\n",
"    \"https://www.irs.gov/newsroom/tax-basics-understanding-the-difference-between-standard-and-itemized-deductions\",\n",
"]\n",
"\n",
"# Load documents from the URLs (each load() call returns a list of Documents)\n",
"docs = [WebBaseLoader(url).load() for url in urls]\n",
"docs_list = [item for sublist in docs for item in sublist]\n",
"\n",
"# Initialize a text splitter; chunk size and overlap are measured in tiktoken tokens\n",
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
"    chunk_size=250, chunk_overlap=0\n",
")\n",
"\n",
"# Split the documents into chunks\n",
"doc_splits = text_splitter.split_documents(docs_list)\n",
"\n",
"# Add the document chunks to the vector store, embedded locally on the CPU with NomicEmbeddings\n",
"vectorstore = SKLearnVectorStore.from_documents(\n",
"    documents=doc_splits,\n",
"    embedding=NomicEmbeddings(\n",
"        model=\"nomic-embed-text-v1.5\", inference_mode=\"local\", device=\"cpu\"\n",
"    ),\n",
"    # embedding=OpenAIEmbeddings(),  # hosted alternative (requires an OpenAI API key)\n",
")\n",
"# Return the 4 most similar chunks per query\n",
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 4})"
]
},
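{
"cell_type": "markdown",
"id": "cosine-topk-sketch-md",
"metadata": {},
"source": [
"Conceptually, the retriever embeds the query with the same embedding model and returns the chunks whose vectors are closest to it; `SKLearnVectorStore` measures closeness with cosine distance by default. The plain-Python sketch below illustrates top-k cosine-similarity search on toy vectors. It is a simplified illustration with hypothetical function names, not the `SKLearnVectorStore` implementation itself, which relies on scikit-learn."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cosine-topk-sketch",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def cosine_similarity(a, b):\n",
"    \"\"\"Cosine similarity of two equal-length vectors (0.0 if either is all zeros).\"\"\"\n",
"    dot = sum(x * y for x, y in zip(a, b))\n",
"    norm_a = math.sqrt(sum(x * x for x in a))\n",
"    norm_b = math.sqrt(sum(y * y for y in b))\n",
"    if norm_a == 0 or norm_b == 0:\n",
"        return 0.0\n",
"    return dot / (norm_a * norm_b)\n",
"\n",
"\n",
"def top_k_indices(query_vec, doc_vecs, k=4):\n",
"    \"\"\"Indices of the k document vectors most similar to the query, best first.\"\"\"\n",
"    scored = [(i, cosine_similarity(query_vec, vec)) for i, vec in enumerate(doc_vecs)]\n",
"    scored.sort(key=lambda pair: pair[1], reverse=True)\n",
"    return [i for i, _ in scored[:k]]\n",
"\n",
"\n",
"# Toy 2-D vectors; nomic-embed-text-v1.5 embeddings have hundreds of dimensions\n",
"toy_doc_vecs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]\n",
"print(top_k_indices([1.0, 0.0], toy_doc_vecs, k=2))  # [0, 2]"
]
},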
{
"cell_type": "code",
"execution_count": 4,
"id": "f8d54464-37b9-4b48-877e-38fc7620c1ff",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"This cell imports the required classes and initializes the web search tool.\n",
"\n",
"Modules:\n",
"- `Document` from `langchain_core.documents`: Container for a piece of text (`page_content`) and its `metadata`; used to wrap web search results.\n",
"- `TavilySearchResults` from `langchain_community.tools.tavily_search`: Tavily web search tool, used when the vector database does not return relevant documents.\n",
"\n",
"Initialization:\n",
"- `web_search_tool`: An instance of `TavilySearchResults` used to perform web searches (it reads `TAVILY_API_KEY` from the environment).\n",
"\"\"\"\n",
"from langchain_core.documents import Document\n",
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"\n",
"web_search_tool = TavilySearchResults()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "36dad7e6-3752-4939-be70-f87d23d90d6f",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"This cell sets up the question-answering chain:\n",
"1. It imports the required classes: `ChatOllama` for the language model, `PromptTemplate` for creating prompts, and `StrOutputParser` for parsing the output.\n",
"2. It defines a prompt template that instructs the assistant to answer the question concisely using the provided documents.\n",
"3. It initializes the `ChatOllama` language model with the local `llama3.1` model and `temperature=0` to reduce randomness in the output.\n",
"4. It creates a chain (`rag_chain`) that pipes the prompt template into the language model and parses the reply into a plain string.\n",
"\"\"\"\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import PromptTemplate\n",
"from langchain_ollama import ChatOllama\n",
"\n",
"prompt = PromptTemplate(\n",
"    template=\"\"\"You are an assistant for question-answering tasks.\n",
"\n",
"    Use the following documents to answer the question.\n",
"\n",
"    If you don't know the answer, just say that you don't know.\n",
"\n",
"    Use three sentences maximum and keep the answer concise:\n",
"    Question: {question}\n",
"    Documents: {documents}\n",
"    Answer:\n",
"    \"\"\",\n",
"    input_variables=[\"question\", \"documents\"],\n",
")\n",
"\n",
"llm = ChatOllama(\n",
"    model=\"llama3.1\",\n",
"    temperature=0,\n",
")\n",
"\n",
"rag_chain = prompt | llm | StrOutputParser()"
]
},
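{
"cell_type": "markdown",
"id": "format-docs-sketch-md",
"metadata": {},
"source": [
"`rag_chain` fills the `{documents}` placeholder with whatever it is given; in `generate`, that is the list of `Document` objects, which the prompt renders through their string representation (metadata included). A common alternative, sketched here as an optional variation rather than part of the original workflow, is to join only the `page_content` of each document. `format_docs` is a hypothetical helper, and the small `Doc` dataclass stands in for LangChain's `Document` so the cell runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "format-docs-sketch",
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass, field\n",
"\n",
"\n",
"@dataclass\n",
"class Doc:\n",
"    \"\"\"Stand-in for langchain_core.documents.Document (page_content plus metadata).\"\"\"\n",
"\n",
"    page_content: str\n",
"    metadata: dict = field(default_factory=dict)\n",
"\n",
"\n",
"def format_docs(docs, separator=\"\\n\\n\"):\n",
"    \"\"\"Join the text of each document, dropping metadata, with a blank line in between.\"\"\"\n",
"    return separator.join(doc.page_content.strip() for doc in docs)\n",
"\n",
"\n",
"print(format_docs([Doc(\"  First chunk. \"), Doc(\"Second chunk.\", {\"url\": \"https://example.com\"})]))\n",
"\n",
"# With real Documents, generate() could pass the joined text instead of the list:\n",
"# rag_chain.invoke({\"documents\": format_docs(documents), \"question\": question})"
]
},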
{
"cell_type": "code",
"execution_count": 6,
"id": "0affbee8-30c4-4dd0-a95a-d8ab571b55c6",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"This cell sets up a prompt template and a retrieval grader for assessing the relevance of a retrieved document to a user question.\n",
"\n",
"Functionality:\n",
"- Imports JsonOutputParser from langchain_core.output_parsers.\n",
"- Defines a PromptTemplate that instructs a grader to assess the relevance of a document to a user question.\n",
"- The grader uses a simple binary score ('yes' or 'no') to indicate relevance.\n",
"- The result is returned as a JSON object with a single key 'score'.\n",
"- Combines the prompt template with the language model (llm) and the JsonOutputParser to create the retrieval_grader.\n",
"\n",
"The workflow uses retrieval_grader to filter out irrelevant retrieved documents before generating an answer.\n",
"\"\"\"\n",
"from langchain_core.output_parsers import JsonOutputParser\n",
"\n",
"prompt = PromptTemplate(\n",
"    template=\"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n\n",
"    Here is the retrieved document: \\n\\n {document} \\n\\n\n",
"    Here is the user question: {question} \\n\n",
"    If the document contains keywords related to the user question, grade it as relevant. \\n\n",
"    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n",
"    Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question. \\n\n",
"    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\"\"\",\n",
"    input_variables=[\"question\", \"document\"],\n",
")\n",
"\n",
"retrieval_grader = prompt | llm | JsonOutputParser()"
]
},
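{
"cell_type": "markdown",
"id": "parse-grade-sketch-md",
"metadata": {},
"source": [
"`JsonOutputParser` expects the model to reply with valid JSON. If a local model occasionally wraps the JSON in extra prose, parsing can fail or yield an unexpected value. The sketch below shows one defensive option, assuming you call the model without the JSON parser (for example `(prompt | llm | StrOutputParser()).invoke(...)`) and parse the raw text yourself: take the span from the first `{` to the last `}`, read its `score`, and otherwise fall back to `no`. The helper name `parse_binary_grade` is hypothetical, and treating unparseable replies as irrelevant (which triggers a web search in this workflow) is a design choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "parse-grade-sketch",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import re\n",
"\n",
"\n",
"def parse_binary_grade(raw_reply: str) -> str:\n",
"    \"\"\"Return 'yes' or 'no' from a grader reply; unparseable replies count as 'no'.\"\"\"\n",
"    match = re.search(r\"\\{.*\\}\", raw_reply, re.DOTALL)  # first \"{\" through last \"}\"\n",
"    if match:\n",
"        try:\n",
"            data = json.loads(match.group(0))\n",
"        except json.JSONDecodeError:\n",
"            return \"no\"\n",
"        value = str(data.get(\"score\", \"\")).strip().lower()\n",
"        if value in (\"yes\", \"no\"):\n",
"            return value\n",
"    return \"no\"\n",
"\n",
"\n",
"print(parse_binary_grade('Sure! {\"score\": \"Yes\"}'))  # yes\n",
"print(parse_binary_grade(\"The document is relevant.\"))  # no"
]
},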
{
"cell_type": "code",
"execution_count": 7,
"id": "d672ffdf",
"metadata": {},
"outputs": [],
"source": [
"# This cell defines the state of the graph and imports the modules needed to build and visualize the graph.\n",
"# It includes a TypedDict class `GraphState` that represents the state of the graph with the attributes\n",
"# question, generation, search, documents, and steps. This state is passed between the nodes\n",
"# of the RAG agent workflow.\n",
"\n",
"from IPython.display import Image, display\n",
"from langgraph.graph import END, START, StateGraph\n",
"from typing_extensions import List, TypedDict\n",
"\n",
"\n",
"class GraphState(TypedDict):\n",
"    \"\"\"\n",
"    Represents the state of our graph.\n",
"\n",
"    Attributes:\n",
"        question: user question\n",
"        generation: LLM generation\n",
"        search: whether to run a web search (\"Yes\" or \"No\")\n",
"        documents: list of retrieved documents\n",
"        steps: names of the workflow steps executed so far\n",
"    \"\"\"\n",
"\n",
"    question: str\n",
"    generation: str\n",
"    search: str\n",
"    documents: List[Document]\n",
"    steps: List[str]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2f26efee",
"metadata": {},
"outputs": [],
"source": [
"# This cell contains the node functions and the routing function of the retrieval and answer-generation pipeline.\n",
"# Each node receives the current graph state (a dict) and returns the state keys it updates.\n",
"\n",
"\n",
"def retrieve(state):\n",
"    \"\"\"\n",
"    Retrieve documents from the vector store.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        state (dict): New key added to state, documents, that contains retrieved documents\n",
"    \"\"\"\n",
"    question = state[\"question\"]\n",
"    documents = retriever.invoke(question)\n",
"    steps = state[\"steps\"]\n",
"    steps.append(\"retrieve_documents\")\n",
"    return {\"documents\": documents, \"question\": question, \"steps\": steps}\n",
"\n",
"\n",
"def generate(state):\n",
"    \"\"\"\n",
"    Generate an answer from the question and the documents in the state.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        state (dict): New key added to state, generation, that contains LLM generation\n",
"    \"\"\"\n",
"\n",
"    question = state[\"question\"]\n",
"    documents = state[\"documents\"]\n",
"    generation = rag_chain.invoke({\"documents\": documents, \"question\": question})\n",
"    steps = state[\"steps\"]\n",
"    steps.append(\"generate_answer\")\n",
"    return {\n",
"        \"documents\": documents,\n",
"        \"question\": question,\n",
"        \"generation\": generation,\n",
"        \"steps\": steps,\n",
"    }\n",
"\n",
"\n",
"def grade_documents(state):\n",
"    \"\"\"\n",
"    Determines whether the retrieved documents are relevant to the question.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        state (dict): Updates documents key with only the relevant documents and sets\n",
"            search to \"Yes\" if any document was graded as irrelevant\n",
"    \"\"\"\n",
"\n",
"    question = state[\"question\"]\n",
"    documents = state[\"documents\"]\n",
"    steps = state[\"steps\"]\n",
"    steps.append(\"grade_document_retrieval\")\n",
"    filtered_docs = []\n",
"    search = \"No\"\n",
"    for d in documents:\n",
"        score = retrieval_grader.invoke(\n",
"            {\"question\": question, \"document\": d.page_content}\n",
"        )\n",
"        # Normalize the grade so that replies such as \"Yes\" are also accepted\n",
"        grade = str(score.get(\"score\", \"\")).strip().lower()\n",
"        if grade == \"yes\":\n",
"            filtered_docs.append(d)\n",
"        else:\n",
"            search = \"Yes\"\n",
"    return {\n",
"        \"documents\": filtered_docs,\n",
"        \"question\": question,\n",
"        \"search\": search,\n",
"        \"steps\": steps,\n",
"    }\n",
"\n",
"\n",
"def web_search(state):\n",
"    \"\"\"\n",
"    Run a web search for the question and add the results to the documents.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        state (dict): Updates documents key with appended web results\n",
"    \"\"\"\n",
"\n",
"    question = state[\"question\"]\n",
"    documents = state.get(\"documents\", [])\n",
"    steps = state[\"steps\"]\n",
"    steps.append(\"web_search\")\n",
"    web_results = web_search_tool.invoke({\"query\": question})\n",
"    documents.extend(\n",
"        [\n",
"            Document(page_content=d[\"content\"], metadata={\"url\": d[\"url\"]})\n",
"            for d in web_results\n",
"        ]\n",
"    )\n",
"    return {\"documents\": documents, \"question\": question, \"steps\": steps}\n",
"\n",
"\n",
"def decide_to_generate(state):\n",
"    \"\"\"\n",
"    Determines whether to run a web search or to generate an answer directly.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        str: Name of the next branch, either \"search\" or \"generate\"\n",
"    \"\"\"\n",
"    search = state[\"search\"]\n",
"    if search == \"Yes\":\n",
"        return \"search\"\n",
"    else:\n",
"        return \"generate\""
]
},
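{
"cell_type": "markdown",
"id": "routing-sketch-md",
"metadata": {},
"source": [
"To make the control flow of these functions easier to follow, here is a plain-Python sketch that mirrors the same routing without LangGraph, an LLM, or network access. The stub functions are hypothetical stand-ins for the retriever, the relevance grader, Tavily, and `rag_chain`; the step names match the ones the real nodes append. As in `grade_documents`, a single irrelevant document is enough to trigger a web search."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "routing-sketch",
"metadata": {},
"outputs": [],
"source": [
"def run_agent_sketch(question, retrieve_fn, grade_fn, search_fn, generate_fn):\n",
"    \"\"\"Mirror retrieve -> grade -> (web search) -> generate with plain function calls.\"\"\"\n",
"    steps = [\"retrieve_documents\"]\n",
"    documents = retrieve_fn(question)\n",
"\n",
"    steps.append(\"grade_document_retrieval\")\n",
"    relevant = [doc for doc in documents if grade_fn(question, doc) == \"yes\"]\n",
"    needs_search = len(relevant) < len(documents)  # any \"no\" grade triggers a search\n",
"\n",
"    if needs_search:\n",
"        steps.append(\"web_search\")\n",
"        relevant = relevant + search_fn(question)\n",
"\n",
"    steps.append(\"generate_answer\")\n",
"    return {\"response\": generate_fn(question, relevant), \"steps\": steps}\n",
"\n",
"\n",
"# Hypothetical stubs: the \"vector store\" only contains tax snippets\n",
"tax_docs = [\"The standard deduction is a fixed amount.\", \"Itemized deductions list specific expenses.\"]\n",
"\n",
"\n",
"def stub_retrieve(question):\n",
"    return list(tax_docs)\n",
"\n",
"\n",
"def stub_grade(question, doc):\n",
"    return \"yes\" if \"deduction\" in question.lower() else \"no\"\n",
"\n",
"\n",
"def stub_search(question):\n",
"    return [f\"web result for: {question}\"]\n",
"\n",
"\n",
"def stub_generate(question, docs):\n",
"    return f\"answer built from {len(docs)} document(s)\"\n",
"\n",
"\n",
"stubs = (stub_retrieve, stub_grade, stub_search, stub_generate)\n",
"print(run_agent_sketch(\"What is a standard deduction?\", *stubs))\n",
"print(run_agent_sketch(\"Who won the 2024 cricket world cup?\", *stubs))"
]
},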
{
"cell_type": "code",
"execution_count": 11,
"id": "e056c4c8-fb62-4524-bb38-11b8c2a20326",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Graph\n",
"\"\"\"\n",
"This cell defines and builds a state graph workflow for the agent pipeline described earlier.\n",
"\n",
"The workflow consists of the following nodes:\n",
"- \"retrieve\": Retrieves documents from the vector database.\n",
"- \"grade_documents\": Grades the retrieved documents for relevance.\n",
"- \"generate\": Generates the answer from the relevant documents.\n",
"- \"web_search\": Performs a web search if any retrieved document was graded irrelevant.\n",
"\n",
"The workflow is constructed as follows:\n",
"1. The entry point is set to the \"retrieve\" node, so the first step is to retrieve similar documents from the vector database.\n",
"2. An edge is added from \"retrieve\" to \"grade_documents\".\n",
"3. Conditional edges are added from \"grade_documents\" to either \"web_search\" or \"generate\" based on the decision function `decide_to_generate`.\n",
"4. An edge is added from \"web_search\" to \"generate\".\n",
"5. An edge is added from \"generate\" to the end of the workflow.\n",
"\n",
"Finally, the workflow is compiled into `custom_graph` and displayed as a Mermaid diagram.\n",
"\"\"\"\n",
"workflow = StateGraph(GraphState)\n",
"\n",
"# Define the nodes\n",
"workflow.add_node(\"retrieve\", retrieve)  # retrieve documents from the vector store\n",
"workflow.add_node(\"grade_documents\", grade_documents)  # grade document relevance\n",
"workflow.add_node(\"generate\", generate)  # generate the answer\n",
"workflow.add_node(\"web_search\", web_search)  # web search\n",
"\n",
"# Build graph\n",
"workflow.set_entry_point(\"retrieve\")\n",
"workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
"workflow.add_conditional_edges(\n",
"    \"grade_documents\",\n",
"    decide_to_generate,\n",
"    {\"search\": \"web_search\", \"generate\": \"generate\"},\n",
")\n",
"workflow.add_edge(\"web_search\", \"generate\")\n",
"workflow.add_edge(\"generate\", END)\n",
"\n",
"custom_graph = workflow.compile()\n",
"\n",
"# draw_mermaid_png() renders the diagram through the mermaid.ink web API by default, so it needs network access\n",
"display(Image(custom_graph.get_graph(xray=True).draw_mermaid_png()))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f26919fb-85ac-4afc-aaf7-cbb222dcd737",
"metadata": {},
"outputs": [],
"source": [
"# This cell defines a function that runs the custom agent graph on an example input and returns its answer.\n",
"import uuid\n",
"\n",
"\n",
"def predict_custom_agent_answer(example: dict):\n",
"    \"\"\"\n",
"    Predicts the answer from a custom agent based on the provided example input.\n",
"\n",
"    Args:\n",
"        example (dict): A dictionary containing the input question under the key \"input\".\n",
"\n",
"    Returns:\n",
"        dict: A dictionary containing the response generated by the custom agent under the key \"response\",\n",
"        and the steps taken during the generation process under the key \"steps\".\n",
"\n",
"    The `config` dictionary passes configuration settings to the custom graph.\n",
"    Here it contains a unique `thread_id` generated with `uuid.uuid4()`. Because the graph is\n",
"    compiled without a checkpointer, the `thread_id` does not persist any state; it only labels\n",
"    each invocation, which can be useful for tracing and debugging.\n",
"    \"\"\"\n",
"\n",
"    config = {\"configurable\": {\"thread_id\": str(uuid.uuid4())}}\n",
"\n",
"    state_dict = custom_graph.invoke(\n",
"        {\"question\": example[\"input\"], \"steps\": []}, config\n",
"    )\n",
"\n",
"    return {\"response\": state_dict[\"generation\"], \"steps\": state_dict[\"steps\"]}"
]
},
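{
"cell_type": "markdown",
"id": "trajectory-check-sketch-md",
"metadata": {},
"source": [
"Because `predict_custom_agent_answer` also returns the list of executed `steps`, you can check which path the agent took. The sketch below compares a `steps` list with the two trajectories this workflow can produce (the step names come from the node functions defined earlier); the dictionary and function names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "trajectory-check-sketch",
"metadata": {},
"outputs": [],
"source": [
"# The two possible paths through the graph, using the step names appended by the nodes\n",
"EXPECTED_TRAJECTORIES = {\n",
"    \"vectorstore\": [\"retrieve_documents\", \"grade_document_retrieval\", \"generate_answer\"],\n",
"    \"web_search\": [\n",
"        \"retrieve_documents\",\n",
"        \"grade_document_retrieval\",\n",
"        \"web_search\",\n",
"        \"generate_answer\",\n",
"    ],\n",
"}\n",
"\n",
"\n",
"def classify_trajectory(steps):\n",
"    \"\"\"Return the name of the matching trajectory, or None if the steps match neither.\"\"\"\n",
"    for name, expected in EXPECTED_TRAJECTORIES.items():\n",
"        if list(steps) == expected:\n",
"            return name\n",
"    return None\n",
"\n",
"\n",
"# After running one of the example cells, e.g.: classify_trajectory(response[\"steps\"])\n",
"print(classify_trajectory([\"retrieve_documents\", \"grade_document_retrieval\", \"generate_answer\"]))"
]
},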
{
"cell_type": "code",
"execution_count": 13,
"id": "5261f17e-3b6a-43df-ad5d-17ad9639e8dd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'response': 'The standard deduction is a fixed amount that most taxpayers can claim, while itemized deductions are specific expenses like mortgage interest, charitable donations, and medical expenses that can be deducted from taxable income. Taxpayers choose the option that gives them the lowest overall tax.',\n",
" 'steps': ['retrieve_documents',\n",
" 'grade_document_retrieval',\n",
" 'generate_answer']}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Define an example question about the difference between the standard deduction and itemized deductions,\n",
"# then use `predict_custom_agent_answer` to generate and display a response.\n",
"# Since this question is related to tax deductions, the agent should answer from the loaded tax documents.\n",
"example = {\n",
"    \"input\": \"What is the difference between standard deduction and itemized deduction?\"\n",
"}\n",
"response = predict_custom_agent_answer(example)\n",
"response"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "627e38d9-3e0a-4094-b1fd-917fb89cc5bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'response': 'India won the 2024 cricket world cup and Virat Kohli was named Player of the Match for. The final match was played between India and South Africa on June 29, 2024. India defeated South Africa by 7 runs to win their second T20 World Cup title.',\n",
" 'steps': ['retrieve_documents',\n",
" 'grade_document_retrieval',\n",
" 'web_search',\n",
" 'generate_answer']}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Define another example question, this time about a sports event,\n",
"# then use `predict_custom_agent_answer` to generate and display a response.\n",
"# Since this question is NOT related to taxes, the agent should answer from the documents returned by web search.\n",
"example = {\"input\": \"Who won the 2024 cricket world cup and who was the MVP in final?\"}\n",
"response = predict_custom_agent_answer(example)\n",
"response"
]
},
{
"cell_type": "markdown",
"id": "2caa78d6-f2aa-41eb-9298-f16ba6e467ba",
"metadata": {},
"source": [
"As the two examples show, the agent answers tax-related questions from documents retrieved from the vector database, and routes other questions, such as the cricket question, through web search, as reflected in the `steps` of each response. Strictly speaking, the route is decided by the relevance grader: a web search is triggered whenever any retrieved chunk is graded as irrelevant."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87a59300-c8ab-4281-9a31-25d37a5149f3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "test-env-langchain",
"language": "python",
"name": "test-env-langchain"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}