langchain: Add the full code snippet in rag.ipynb (#29820)

docs(rag.ipynb): add `full code` snippets; a complete, runnable example is useful for beginners working through the tutorial.

Preview the change:
https://langchain-git-fork-googtech-patch-3-langchain.vercel.app/docs/tutorials/rag/

Two `full code` snippets are added, as shown below:
<details>
<summary>Full Code:</summary>

```python
import bs4
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chat_models import init_chat_model
from langchain_openai import OpenAIEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from google.colab import userdata  # Colab-only: reads secrets stored in Colab
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document
from typing_extensions import List, TypedDict
from langgraph.graph import START, StateGraph

#################################################
# 1.Initialize the ChatModel and EmbeddingModel #
#################################################
llm = init_chat_model(
    model="gpt-4o-mini",
    model_provider="openai",
    openai_api_key=userdata.get('OPENAI_API_KEY'),
    base_url=userdata.get('BASE_URL'),
)
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    openai_api_key=userdata.get('OPENAI_API_KEY'),
    base_url=userdata.get('BASE_URL'),
)

#######################
# 2.Loading documents #
#######################
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs=dict(
        # Only keep post title, headers, and content from the full HTML.
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )
    ),
)
docs = loader.load()

#########################
# 3.Splitting documents #
#########################
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,  # chunk size (characters)
    chunk_overlap=200,  # chunk overlap (characters)
    add_start_index=True,  # track index in original document
)
all_splits = text_splitter.split_documents(docs)

###########################################################
# 4.Embedding documents and storing them in a vectorstore #
###########################################################
vector_store = InMemoryVectorStore(embeddings)
_ = vector_store.add_documents(documents=all_splits)
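
# Optional sanity check (illustrative query): search the index directly
# to confirm the chunks were embedded and stored.
# vector_store.similarity_search("What is Task Decomposition?", k=1)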

##########################################################
# 5.Customizing the prompt or loading it from Prompt Hub #
##########################################################
# Alternatively, load the prompt from the Prompt Hub (requires `from langchain import hub`):
# prompt = hub.pull("rlm/rag-prompt")
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.

{context}

Question: {question}

Helpful Answer:"""
prompt = PromptTemplate.from_template(template)

##################################################################################################
# 6.Using LangGraph to tie together the retrieval and generation steps into a single application #
##################################################################################################
# 6.1.Define the application state, which controls the data flowing through the application
class State(TypedDict):
    question: str
    context: List[Document]
    answer: str

# 6.2.1.Define the retrieval node, which fetches the documents relevant to the question
def retrieve(state: State):
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}

# 6.2.2.Define the generation node, which produces an answer from the retrieved context
def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}

# 6.3.Define the control flow, which specifies the ordering of the application steps
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
```

</details>
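
Once compiled, the graph can be invoked with a question; a minimal usage sketch (the example question is illustrative):

```python
result = graph.invoke({"question": "What is Task Decomposition?"})
print(result["answer"])
```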

<details>
<summary>Full Code:</summary>

```python
import bs4
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chat_models import init_chat_model
from langchain_openai import OpenAIEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from google.colab import userdata  # Colab-only: reads secrets stored in Colab
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document
from typing import Literal
from typing_extensions import Annotated, List, TypedDict
from langgraph.graph import START, StateGraph

#################################################
# 1.Initialize the ChatModel and EmbeddingModel #
#################################################
llm = init_chat_model(
    model="gpt-4o-mini",
    model_provider="openai",
    openai_api_key=userdata.get('OPENAI_API_KEY'),
    base_url=userdata.get('BASE_URL'),
)
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    openai_api_key=userdata.get('OPENAI_API_KEY'),
    base_url=userdata.get('BASE_URL'),
)

#######################
# 2.Loading documents #
#######################
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs=dict(
        # Only keep post title, headers, and content from the full HTML.
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )
    ),
)
docs = loader.load()

#########################
# 3.Splitting documents #
#########################
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,  # chunk size (characters)
    chunk_overlap=200,  # chunk overlap (characters)
    add_start_index=True,  # track index in original document
)
all_splits = text_splitter.split_documents(docs)

# Search analysis: add section metadata to the document chunks
# so that we can filter on section later.
total_documents = len(all_splits)
third = total_documents // 3
for i, document in enumerate(all_splits):
    if i < third:
        document.metadata["section"] = "beginning"
    elif i < 2 * third:
        document.metadata["section"] = "middle"
    else:
        document.metadata["section"] = "end"

# Search analysis: Define the schema for our search query
class Search(TypedDict):
    query: Annotated[str, ..., "Search query to run."]
    section: Annotated[
        Literal["beginning", "middle", "end"], ..., "Section to query."]

###########################################################
# 4.Embedding documents and storing them in a vectorstore #
###########################################################
vector_store = InMemoryVectorStore(embeddings)
_ = vector_store.add_documents(documents=all_splits)

##########################################################
# 5.Customizing the prompt or loading it from Prompt Hub #
##########################################################
# Alternatively, load the prompt from the Prompt Hub (requires `from langchain import hub`):
# prompt = hub.pull("rlm/rag-prompt")
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.

{context}

Question: {question}

Helpful Answer:"""
prompt = PromptTemplate.from_template(template)

###################################################################
# 6.Using LangGraph to tie together the analyze_query, retrieval  #
# and generation steps into a single application                  #
###################################################################
# 6.1.Define the application state, which controls the data flowing through the application
class State(TypedDict):
    question: str
    query: Search
    context: List[Document]
    answer: str

# 6.2.1.Search analysis: define the query-analysis node,
# which generates a structured query from the user's raw input
def analyze_query(state: State):
    structured_llm = llm.with_structured_output(Search)
    query = structured_llm.invoke(state["question"])
    return {"query": query}

# 6.2.2.Define the retrieval node, which fetches documents matching the structured query
def retrieve(state: State):
    query = state["query"]
    retrieved_docs = vector_store.similarity_search(
        query["query"],
        filter=lambda doc: doc.metadata.get("section") == query["section"],
    )
    return {"context": retrieved_docs}

# 6.2.3.Define the generation node, which produces an answer from the retrieved context
def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}

# 6.3.Define the control flow, which specifies the ordering of the application steps
graph_builder = StateGraph(State).add_sequence([analyze_query, retrieve, generate])
graph_builder.add_edge(START, "analyze_query")
graph = graph_builder.compile()
```

</details>
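
As before, the compiled graph can be invoked with a question; here the `analyze_query` step also extracts a section filter before retrieval. A minimal usage sketch (the example question is illustrative):

```python
result = graph.invoke(
    {"question": "What does the end of the post say about Task Decomposition?"}
)
print(result["answer"])
```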

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Author: HackHuang
Date: 2025-02-16 10:37:58 +08:00
Committer: GitHub
Parent: b2c21f3e57
Commit: 80ca310c15

@@ -1050,6 +1050,112 @@
"graph = graph_builder.compile()"
]
},
{
"cell_type": "markdown",
"id": "28a62d34",
"metadata": {},
"source": [
"<details>\n",
"<summary>Full Code:</summary>\n",
"\n",
"```python\n",
"from typing import Literal\n",
"\n",
"import bs4\n",
"from langchain import hub\n",
"from langchain_community.document_loaders import WebBaseLoader\n",
"from langchain_core.documents import Document\n",
"from langchain_core.vectorstores import InMemoryVectorStore\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"from langgraph.graph import START, StateGraph\n",
"from typing_extensions import Annotated, List, TypedDict\n",
"\n",
"# Load and chunk contents of the blog\n",
"loader = WebBaseLoader(\n",
" web_paths=(\"https://lilianweng.github.io/posts/2023-06-23-agent/\",),\n",
" bs_kwargs=dict(\n",
" parse_only=bs4.SoupStrainer(\n",
" class_=(\"post-content\", \"post-title\", \"post-header\")\n",
" )\n",
" ),\n",
")\n",
"docs = loader.load()\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
"all_splits = text_splitter.split_documents(docs)\n",
"\n",
"\n",
"# Update metadata (illustration purposes)\n",
"total_documents = len(all_splits)\n",
"third = total_documents // 3\n",
"\n",
"for i, document in enumerate(all_splits):\n",
" if i < third:\n",
" document.metadata[\"section\"] = \"beginning\"\n",
" elif i < 2 * third:\n",
" document.metadata[\"section\"] = \"middle\"\n",
" else:\n",
" document.metadata[\"section\"] = \"end\"\n",
"\n",
"\n",
"# Index chunks\n",
"vector_store = InMemoryVectorStore(embeddings)\n",
"_ = vector_store.add_documents(all_splits)\n",
"\n",
"\n",
"# Define schema for search\n",
"class Search(TypedDict):\n",
" \"\"\"Search query.\"\"\"\n",
"\n",
" query: Annotated[str, ..., \"Search query to run.\"]\n",
" section: Annotated[\n",
" Literal[\"beginning\", \"middle\", \"end\"],\n",
" ...,\n",
" \"Section to query.\",\n",
" ]\n",
"\n",
"# Define prompt for question-answering\n",
"prompt = hub.pull(\"rlm/rag-prompt\")\n",
"\n",
"\n",
"# Define state for application\n",
"class State(TypedDict):\n",
" question: str\n",
" query: Search\n",
" context: List[Document]\n",
" answer: str\n",
"\n",
"\n",
"def analyze_query(state: State):\n",
" structured_llm = llm.with_structured_output(Search)\n",
" query = structured_llm.invoke(state[\"question\"])\n",
" return {\"query\": query}\n",
"\n",
"\n",
"def retrieve(state: State):\n",
" query = state[\"query\"]\n",
" retrieved_docs = vector_store.similarity_search(\n",
" query[\"query\"],\n",
" filter=lambda doc: doc.metadata.get(\"section\") == query[\"section\"],\n",
" )\n",
" return {\"context\": retrieved_docs}\n",
"\n",
"\n",
"def generate(state: State):\n",
" docs_content = \"\\n\\n\".join(doc.page_content for doc in state[\"context\"])\n",
" messages = prompt.invoke({\"question\": state[\"question\"], \"context\": docs_content})\n",
" response = llm.invoke(messages)\n",
" return {\"answer\": response.content}\n",
"\n",
"\n",
"graph_builder = StateGraph(State).add_sequence([analyze_query, retrieve, generate])\n",
"graph_builder.add_edge(START, \"analyze_query\")\n",
"graph = graph_builder.compile()\n",
"```\n",
"\n",
"</details>"
]
},
{
"cell_type": "code",
"execution_count": 25,