diff --git a/docs/docs/tutorials/rag.ipynb b/docs/docs/tutorials/rag.ipynb index 511378e3ff9..a4601ce3a21 100644 --- a/docs/docs/tutorials/rag.ipynb +++ b/docs/docs/tutorials/rag.ipynb @@ -1050,6 +1050,112 @@ "graph = graph_builder.compile()" ] }, + { + "cell_type": "markdown", + "id": "28a62d34", + "metadata": {}, + "source": [ + "<details>
\n", + "Full Code:\n", + "\n", + "```python\n", + "from typing import Literal\n", + "\n", + "import bs4\n", + "from langchain import hub\n", + "from langchain_community.document_loaders import WebBaseLoader\n", + "from langchain_core.documents import Document\n", + "from langchain_core.vectorstores import InMemoryVectorStore\n", + "from langchain_text_splitters import RecursiveCharacterTextSplitter\n", + "from langgraph.graph import START, StateGraph\n", + "from typing_extensions import Annotated, List, TypedDict\n", + "\n", + "# Load and chunk contents of the blog\n", + "loader = WebBaseLoader(\n", + " web_paths=(\"https://lilianweng.github.io/posts/2023-06-23-agent/\",),\n", + " bs_kwargs=dict(\n", + " parse_only=bs4.SoupStrainer(\n", + " class_=(\"post-content\", \"post-title\", \"post-header\")\n", + " )\n", + " ),\n", + ")\n", + "docs = loader.load()\n", + "\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", + "all_splits = text_splitter.split_documents(docs)\n", + "\n", + "\n", + "# Update metadata (illustration purposes)\n", + "total_documents = len(all_splits)\n", + "third = total_documents // 3\n", + "\n", + "for i, document in enumerate(all_splits):\n", + " if i < third:\n", + " document.metadata[\"section\"] = \"beginning\"\n", + " elif i < 2 * third:\n", + " document.metadata[\"section\"] = \"middle\"\n", + " else:\n", + " document.metadata[\"section\"] = \"end\"\n", + "\n", + "\n", + "# Index chunks\n", + "vector_store = InMemoryVectorStore(embeddings)\n", + "_ = vector_store.add_documents(all_splits)\n", + "\n", + "\n", + "# Define schema for search\n", + "class Search(TypedDict):\n", + " \"\"\"Search query.\"\"\"\n", + "\n", + " query: Annotated[str, ..., \"Search query to run.\"]\n", + " section: Annotated[\n", + " Literal[\"beginning\", \"middle\", \"end\"],\n", + " ...,\n", + " \"Section to query.\",\n", + " ]\n", + "\n", + "# Define prompt for question-answering\n", + "prompt = 
hub.pull(\"rlm/rag-prompt\")\n", + "\n", + "\n", + "# Define state for application\n", + "class State(TypedDict):\n", + " question: str\n", + " query: Search\n", + " context: List[Document]\n", + " answer: str\n", + "\n", + "\n", + "def analyze_query(state: State):\n", + " structured_llm = llm.with_structured_output(Search)\n", + " query = structured_llm.invoke(state[\"question\"])\n", + " return {\"query\": query}\n", + "\n", + "\n", + "def retrieve(state: State):\n", + " query = state[\"query\"]\n", + " retrieved_docs = vector_store.similarity_search(\n", + " query[\"query\"],\n", + " filter=lambda doc: doc.metadata.get(\"section\") == query[\"section\"],\n", + " )\n", + " return {\"context\": retrieved_docs}\n", + "\n", + "\n", + "def generate(state: State):\n", + " docs_content = \"\\n\\n\".join(doc.page_content for doc in state[\"context\"])\n", + " messages = prompt.invoke({\"question\": state[\"question\"], \"context\": docs_content})\n", + " response = llm.invoke(messages)\n", + " return {\"answer\": response.content}\n", + "\n", + "\n", + "graph_builder = StateGraph(State).add_sequence([analyze_query, retrieve, generate])\n", + "graph_builder.add_edge(START, \"analyze_query\")\n", + "graph = graph_builder.compile()\n", + "```\n", + "\n", + "
" + ] + }, { "cell_type": "code", "execution_count": 25,