mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-08 13:50:00 +00:00
notebook fmt (#12498)
This commit is contained in:
@@ -115,7 +115,9 @@
|
||||
"agent = (\n",
|
||||
" {\n",
|
||||
" \"question\": lambda x: x[\"question\"],\n",
|
||||
" \"intermediate_steps\": lambda x: convert_intermediate_steps(x[\"intermediate_steps\"])\n",
|
||||
" \"intermediate_steps\": lambda x: convert_intermediate_steps(\n",
|
||||
" x[\"intermediate_steps\"]\n",
|
||||
" ),\n",
|
||||
" }\n",
|
||||
" | prompt.partial(tools=convert_tools(tool_list))\n",
|
||||
" | model.bind(stop=[\"</tool_input>\", \"</final_answer>\"])\n",
|
||||
|
@@ -18,7 +18,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
|
||||
"from langchain.prompts import (\n",
|
||||
" ChatPromptTemplate,\n",
|
||||
" SystemMessagePromptTemplate,\n",
|
||||
" HumanMessagePromptTemplate,\n",
|
||||
")\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain_experimental.utilities import PythonREPL"
|
||||
]
|
||||
@@ -37,9 +41,7 @@
|
||||
"```python\n",
|
||||
"....\n",
|
||||
"```\"\"\"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [(\"system\", template), (\"human\", \"{input}\")]\n",
|
||||
")\n",
|
||||
"prompt = ChatPromptTemplate.from_messages([(\"system\", template), (\"human\", \"{input}\")])\n",
|
||||
"\n",
|
||||
"model = ChatOpenAI()"
|
||||
]
|
||||
|
@@ -24,11 +24,13 @@
|
||||
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
|
||||
"\n",
|
||||
"model = ChatOpenAI()\n",
|
||||
"prompt = ChatPromptTemplate.from_messages([\n",
|
||||
" (\"system\", \"You are a helpful chatbot\"),\n",
|
||||
" MessagesPlaceholder(variable_name=\"history\"),\n",
|
||||
" (\"human\", \"{input}\")\n",
|
||||
"])\n"
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"You are a helpful chatbot\"),\n",
|
||||
" MessagesPlaceholder(variable_name=\"history\"),\n",
|
||||
" (\"human\", \"{input}\"),\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -38,7 +40,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"memory = ConversationBufferMemory(return_messages=True)\n"
|
||||
"memory = ConversationBufferMemory(return_messages=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -59,7 +61,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"memory.load_memory_variables({})\n"
|
||||
"memory.load_memory_variables({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -69,9 +71,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = RunnablePassthrough.assign(\n",
|
||||
" memory=RunnableLambda(memory.load_memory_variables) | itemgetter(\"history\")\n",
|
||||
") | prompt | model\n"
|
||||
"chain = (\n",
|
||||
" RunnablePassthrough.assign(\n",
|
||||
" memory=RunnableLambda(memory.load_memory_variables) | itemgetter(\"history\")\n",
|
||||
" )\n",
|
||||
" | prompt\n",
|
||||
" | model\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -94,7 +100,7 @@
|
||||
"source": [
|
||||
"inputs = {\"input\": \"hi im bob\"}\n",
|
||||
"response = chain.invoke(inputs)\n",
|
||||
"response\n"
|
||||
"response"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -104,7 +110,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"memory.save_context(inputs, {\"output\": response.content})\n"
|
||||
"memory.save_context(inputs, {\"output\": response.content})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -126,7 +132,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"memory.load_memory_variables({})\n"
|
||||
"memory.load_memory_variables({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -149,7 +155,7 @@
|
||||
"source": [
|
||||
"inputs = {\"input\": \"whats my name\"}\n",
|
||||
"response = chain.invoke(inputs)\n",
|
||||
"response\n"
|
||||
"response"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@@ -40,9 +40,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = OpenAI()\n",
|
||||
"prompt = ChatPromptTemplate.from_messages([\n",
|
||||
" (\"system\", \"repeat after me: {input}\")\n",
|
||||
"])"
|
||||
"prompt = ChatPromptTemplate.from_messages([(\"system\", \"repeat after me: {input}\")])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -44,13 +44,20 @@
|
||||
"from langchain.schema import StrOutputParser\n",
|
||||
"\n",
|
||||
"prompt1 = ChatPromptTemplate.from_template(\"what is the city {person} is from?\")\n",
|
||||
"prompt2 = ChatPromptTemplate.from_template(\"what country is the city {city} in? respond in {language}\")\n",
|
||||
"prompt2 = ChatPromptTemplate.from_template(\n",
|
||||
" \"what country is the city {city} in? respond in {language}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"model = ChatOpenAI()\n",
|
||||
"\n",
|
||||
"chain1 = prompt1 | model | StrOutputParser()\n",
|
||||
"\n",
|
||||
"chain2 = {\"city\": chain1, \"language\": itemgetter(\"language\")} | prompt2 | model | StrOutputParser()\n",
|
||||
"chain2 = (\n",
|
||||
" {\"city\": chain1, \"language\": itemgetter(\"language\")}\n",
|
||||
" | prompt2\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain2.invoke({\"person\": \"obama\", \"language\": \"spanish\"})"
|
||||
]
|
||||
@@ -64,17 +71,29 @@
|
||||
"source": [
|
||||
"from langchain.schema.runnable import RunnableMap, RunnablePassthrough\n",
|
||||
"\n",
|
||||
"prompt1 = ChatPromptTemplate.from_template(\"generate a {attribute} color. Return the name of the color and nothing else:\")\n",
|
||||
"prompt2 = ChatPromptTemplate.from_template(\"what is a fruit of color: {color}. Return the name of the fruit and nothing else:\")\n",
|
||||
"prompt3 = ChatPromptTemplate.from_template(\"what is a country with a flag that has the color: {color}. Return the name of the country and nothing else:\")\n",
|
||||
"prompt4 = ChatPromptTemplate.from_template(\"What is the color of {fruit} and the flag of {country}?\")\n",
|
||||
"prompt1 = ChatPromptTemplate.from_template(\n",
|
||||
" \"generate a {attribute} color. Return the name of the color and nothing else:\"\n",
|
||||
")\n",
|
||||
"prompt2 = ChatPromptTemplate.from_template(\n",
|
||||
" \"what is a fruit of color: {color}. Return the name of the fruit and nothing else:\"\n",
|
||||
")\n",
|
||||
"prompt3 = ChatPromptTemplate.from_template(\n",
|
||||
" \"what is a country with a flag that has the color: {color}. Return the name of the country and nothing else:\"\n",
|
||||
")\n",
|
||||
"prompt4 = ChatPromptTemplate.from_template(\n",
|
||||
" \"What is the color of {fruit} and the flag of {country}?\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"model_parser = model | StrOutputParser()\n",
|
||||
"\n",
|
||||
"color_generator = {\"attribute\": RunnablePassthrough()} | prompt1 | {\"color\": model_parser}\n",
|
||||
"color_generator = (\n",
|
||||
" {\"attribute\": RunnablePassthrough()} | prompt1 | {\"color\": model_parser}\n",
|
||||
")\n",
|
||||
"color_to_fruit = prompt2 | model_parser\n",
|
||||
"color_to_country = prompt3 | model_parser\n",
|
||||
"question_generator = color_generator | {\"fruit\": color_to_fruit, \"country\": color_to_country} | prompt4"
|
||||
"question_generator = (\n",
|
||||
" color_generator | {\"fruit\": color_to_fruit, \"country\": color_to_country} | prompt4\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -148,9 +167,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"planner = (\n",
|
||||
" ChatPromptTemplate.from_template(\n",
|
||||
" \"Generate an argument about: {input}\"\n",
|
||||
" )\n",
|
||||
" ChatPromptTemplate.from_template(\"Generate an argument about: {input}\")\n",
|
||||
" | ChatOpenAI()\n",
|
||||
" | StrOutputParser()\n",
|
||||
" | {\"base_response\": RunnablePassthrough()}\n",
|
||||
@@ -163,7 +180,7 @@
|
||||
" | ChatOpenAI()\n",
|
||||
" | StrOutputParser()\n",
|
||||
")\n",
|
||||
"arguments_against = (\n",
|
||||
"arguments_against = (\n",
|
||||
" ChatPromptTemplate.from_template(\n",
|
||||
" \"List the cons or negative aspects of {base_response}\"\n",
|
||||
" )\n",
|
||||
@@ -184,7 +201,7 @@
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = (\n",
|
||||
" planner \n",
|
||||
" planner\n",
|
||||
" | {\n",
|
||||
" \"results_1\": arguments_for,\n",
|
||||
" \"results_2\": arguments_against,\n",
|
||||
|
@@ -47,7 +47,7 @@
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")\n",
|
||||
"model = ChatOpenAI()\n",
|
||||
"chain = prompt | model\n"
|
||||
"chain = prompt | model"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -68,7 +68,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"foo\": \"bears\"})\n"
|
||||
"chain.invoke({\"foo\": \"bears\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -94,7 +94,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = prompt | model.bind(stop=[\"\\n\"])\n"
|
||||
"chain = prompt | model.bind(stop=[\"\\n\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -115,7 +115,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"foo\": \"bears\"})\n"
|
||||
"chain.invoke({\"foo\": \"bears\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -135,25 +135,22 @@
|
||||
"source": [
|
||||
"functions = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"joke\",\n",
|
||||
" \"description\": \"A joke\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"setup\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The setup for the joke\"\n",
|
||||
" },\n",
|
||||
" \"punchline\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The punchline for the joke\"\n",
|
||||
" }\n",
|
||||
" \"name\": \"joke\",\n",
|
||||
" \"description\": \"A joke\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"setup\": {\"type\": \"string\", \"description\": \"The setup for the joke\"},\n",
|
||||
" \"punchline\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The punchline for the joke\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"required\": [\"setup\", \"punchline\"],\n",
|
||||
" },\n",
|
||||
" \"required\": [\"setup\", \"punchline\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"chain = prompt | model.bind(function_call= {\"name\": \"joke\"}, functions= functions)\n"
|
||||
"]\n",
|
||||
"chain = prompt | model.bind(function_call={\"name\": \"joke\"}, functions=functions)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -174,7 +171,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"foo\": \"bears\"}, config={})\n"
|
||||
"chain.invoke({\"foo\": \"bears\"}, config={})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -196,7 +193,7 @@
|
||||
"source": [
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"\n",
|
||||
"chain = prompt | model | StrOutputParser()\n"
|
||||
"chain = prompt | model | StrOutputParser()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -225,7 +222,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"foo\": \"bears\"})\n"
|
||||
"chain.invoke({\"foo\": \"bears\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -248,10 +245,10 @@
|
||||
"from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser\n",
|
||||
"\n",
|
||||
"chain = (\n",
|
||||
" prompt \n",
|
||||
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
|
||||
" prompt\n",
|
||||
" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
|
||||
" | JsonOutputFunctionsParser()\n",
|
||||
")\n"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -273,7 +270,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"foo\": \"bears\"})\n"
|
||||
"chain.invoke({\"foo\": \"bears\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -286,10 +283,10 @@
|
||||
"from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n",
|
||||
"\n",
|
||||
"chain = (\n",
|
||||
" prompt \n",
|
||||
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
|
||||
" prompt\n",
|
||||
" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
|
||||
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
|
||||
")\n"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -310,7 +307,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"foo\": \"bears\"})\n"
|
||||
"chain.invoke({\"foo\": \"bears\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -334,11 +331,11 @@
|
||||
"\n",
|
||||
"map_ = RunnableMap(foo=RunnablePassthrough())\n",
|
||||
"chain = (\n",
|
||||
" map_ \n",
|
||||
" map_\n",
|
||||
" | prompt\n",
|
||||
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
|
||||
" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
|
||||
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
|
||||
")\n"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -359,7 +356,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke(\"bears\")\n"
|
||||
"chain.invoke(\"bears\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -378,11 +375,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = (\n",
|
||||
" {\"foo\": RunnablePassthrough()} \n",
|
||||
" {\"foo\": RunnablePassthrough()}\n",
|
||||
" | prompt\n",
|
||||
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
|
||||
" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
|
||||
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
|
||||
")\n"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -403,7 +400,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke(\"bears\")\n"
|
||||
"chain.invoke(\"bears\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@@ -26,7 +26,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install langchain openai faiss-cpu tiktoken\n"
|
||||
"!pip install langchain openai faiss-cpu tiktoken"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -43,7 +43,7 @@
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain.schema.runnable import RunnablePassthrough, RunnableLambda\n",
|
||||
"from langchain.vectorstores import FAISS\n"
|
||||
"from langchain.vectorstores import FAISS"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -53,7 +53,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vectorstore = FAISS.from_texts([\"harrison worked at kensho\"], embedding=OpenAIEmbeddings())\n",
|
||||
"vectorstore = FAISS.from_texts(\n",
|
||||
" [\"harrison worked at kensho\"], embedding=OpenAIEmbeddings()\n",
|
||||
")\n",
|
||||
"retriever = vectorstore.as_retriever()\n",
|
||||
"\n",
|
||||
"template = \"\"\"Answer the question based only on the following context:\n",
|
||||
@@ -63,7 +65,7 @@
|
||||
"\"\"\"\n",
|
||||
"prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
"model = ChatOpenAI()\n"
|
||||
"model = ChatOpenAI()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -74,11 +76,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = (\n",
|
||||
" {\"context\": retriever, \"question\": RunnablePassthrough()} \n",
|
||||
" | prompt \n",
|
||||
" | model \n",
|
||||
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
|
||||
" | prompt\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
")\n"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -99,7 +101,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke(\"where did harrison work?\")\n"
|
||||
"chain.invoke(\"where did harrison work?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -118,11 +120,16 @@
|
||||
"\"\"\"\n",
|
||||
"prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
"chain = {\n",
|
||||
" \"context\": itemgetter(\"question\") | retriever, \n",
|
||||
" \"question\": itemgetter(\"question\"), \n",
|
||||
" \"language\": itemgetter(\"language\")\n",
|
||||
"} | prompt | model | StrOutputParser()\n"
|
||||
"chain = (\n",
|
||||
" {\n",
|
||||
" \"context\": itemgetter(\"question\") | retriever,\n",
|
||||
" \"question\": itemgetter(\"question\"),\n",
|
||||
" \"language\": itemgetter(\"language\"),\n",
|
||||
" }\n",
|
||||
" | prompt\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -143,7 +150,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})\n"
|
||||
"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -164,7 +171,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema.runnable import RunnableMap\n",
|
||||
"from langchain.schema import format_document\n"
|
||||
"from langchain.schema import format_document"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -182,7 +189,7 @@
|
||||
"{chat_history}\n",
|
||||
"Follow Up Input: {question}\n",
|
||||
"Standalone question:\"\"\"\n",
|
||||
"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)\n"
|
||||
"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -197,7 +204,7 @@
|
||||
"\n",
|
||||
"Question: {question}\n",
|
||||
"\"\"\"\n",
|
||||
"ANSWER_PROMPT = ChatPromptTemplate.from_template(template)\n"
|
||||
"ANSWER_PROMPT = ChatPromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -208,9 +215,13 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=\"{page_content}\")\n",
|
||||
"def _combine_documents(docs, document_prompt = DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _combine_documents(\n",
|
||||
" docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"\n",
|
||||
"):\n",
|
||||
" doc_strings = [format_document(doc, document_prompt) for doc in docs]\n",
|
||||
" return document_separator.join(doc_strings)\n"
|
||||
" return document_separator.join(doc_strings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -221,13 +232,15 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Tuple, List\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _format_chat_history(chat_history: List[Tuple]) -> str:\n",
|
||||
" buffer = \"\"\n",
|
||||
" for dialogue_turn in chat_history:\n",
|
||||
" human = \"Human: \" + dialogue_turn[0]\n",
|
||||
" ai = \"Assistant: \" + dialogue_turn[1]\n",
|
||||
" buffer += \"\\n\" + \"\\n\".join([human, ai])\n",
|
||||
" return buffer\n"
|
||||
" return buffer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -239,14 +252,17 @@
|
||||
"source": [
|
||||
"_inputs = RunnableMap(\n",
|
||||
" standalone_question=RunnablePassthrough.assign(\n",
|
||||
" chat_history=lambda x: _format_chat_history(x['chat_history'])\n",
|
||||
" ) | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
|
||||
" chat_history=lambda x: _format_chat_history(x[\"chat_history\"])\n",
|
||||
" )\n",
|
||||
" | CONDENSE_QUESTION_PROMPT\n",
|
||||
" | ChatOpenAI(temperature=0)\n",
|
||||
" | StrOutputParser(),\n",
|
||||
")\n",
|
||||
"_context = {\n",
|
||||
" \"context\": itemgetter(\"standalone_question\") | retriever | _combine_documents,\n",
|
||||
" \"question\": lambda x: x[\"standalone_question\"]\n",
|
||||
" \"question\": lambda x: x[\"standalone_question\"],\n",
|
||||
"}\n",
|
||||
"conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()\n"
|
||||
"conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -267,10 +283,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversational_qa_chain.invoke({\n",
|
||||
" \"question\": \"where did harrison work?\",\n",
|
||||
" \"chat_history\": [],\n",
|
||||
"})\n"
|
||||
"conversational_qa_chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"question\": \"where did harrison work?\",\n",
|
||||
" \"chat_history\": [],\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -291,10 +309,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversational_qa_chain.invoke({\n",
|
||||
" \"question\": \"where did he work?\",\n",
|
||||
" \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n",
|
||||
"})\n"
|
||||
"conversational_qa_chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"question\": \"where did he work?\",\n",
|
||||
" \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -315,7 +335,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from operator import itemgetter\n",
|
||||
"from langchain.memory import ConversationBufferMemory\n"
|
||||
"from langchain.memory import ConversationBufferMemory"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -325,7 +345,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"memory = ConversationBufferMemory(return_messages=True, output_key=\"answer\", input_key=\"question\")\n"
|
||||
"memory = ConversationBufferMemory(\n",
|
||||
" return_messages=True, output_key=\"answer\", input_key=\"question\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -344,18 +366,21 @@
|
||||
"standalone_question = {\n",
|
||||
" \"standalone_question\": {\n",
|
||||
" \"question\": lambda x: x[\"question\"],\n",
|
||||
" \"chat_history\": lambda x: _format_chat_history(x['chat_history'])\n",
|
||||
" } | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
|
||||
" \"chat_history\": lambda x: _format_chat_history(x[\"chat_history\"]),\n",
|
||||
" }\n",
|
||||
" | CONDENSE_QUESTION_PROMPT\n",
|
||||
" | ChatOpenAI(temperature=0)\n",
|
||||
" | StrOutputParser(),\n",
|
||||
"}\n",
|
||||
"# Now we retrieve the documents\n",
|
||||
"retrieved_documents = {\n",
|
||||
" \"docs\": itemgetter(\"standalone_question\") | retriever,\n",
|
||||
" \"question\": lambda x: x[\"standalone_question\"]\n",
|
||||
" \"question\": lambda x: x[\"standalone_question\"],\n",
|
||||
"}\n",
|
||||
"# Now we construct the inputs for the final prompt\n",
|
||||
"final_inputs = {\n",
|
||||
" \"context\": lambda x: _combine_documents(x[\"docs\"]),\n",
|
||||
" \"question\": itemgetter(\"question\")\n",
|
||||
" \"question\": itemgetter(\"question\"),\n",
|
||||
"}\n",
|
||||
"# And finally, we do the part that returns the answers\n",
|
||||
"answer = {\n",
|
||||
@@ -363,7 +388,7 @@
|
||||
" \"docs\": itemgetter(\"docs\"),\n",
|
||||
"}\n",
|
||||
"# And now we put it all together!\n",
|
||||
"final_chain = loaded_memory | standalone_question | retrieved_documents | answer\n"
|
||||
"final_chain = loaded_memory | standalone_question | retrieved_documents | answer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -387,7 +412,7 @@
|
||||
"source": [
|
||||
"inputs = {\"question\": \"where did harrison work?\"}\n",
|
||||
"result = final_chain.invoke(inputs)\n",
|
||||
"result\n"
|
||||
"result"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -400,7 +425,7 @@
|
||||
"# Note that the memory does not save automatically\n",
|
||||
"# This will be improved in the future\n",
|
||||
"# For now you need to save it yourself\n",
|
||||
"memory.save_context(inputs, {\"answer\": result[\"answer\"].content})\n"
|
||||
"memory.save_context(inputs, {\"answer\": result[\"answer\"].content})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -422,7 +447,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"memory.load_memory_variables({})\n"
|
||||
"memory.load_memory_variables({})"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@@ -33,7 +33,7 @@
|
||||
"\n",
|
||||
"Question: {question}\n",
|
||||
"SQL Query:\"\"\"\n",
|
||||
"prompt = ChatPromptTemplate.from_template(template)\n"
|
||||
"prompt = ChatPromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -43,7 +43,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.utilities import SQLDatabase\n"
|
||||
"from langchain.utilities import SQLDatabase"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -61,7 +61,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")\n"
|
||||
"db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -72,7 +72,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_schema(_):\n",
|
||||
" return db.get_table_info()\n"
|
||||
" return db.get_table_info()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -83,7 +83,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def run_query(query):\n",
|
||||
" return db.run(query)\n"
|
||||
" return db.run(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -100,11 +100,11 @@
|
||||
"model = ChatOpenAI()\n",
|
||||
"\n",
|
||||
"sql_response = (\n",
|
||||
" RunnablePassthrough.assign(schema=get_schema)\n",
|
||||
" | prompt\n",
|
||||
" | model.bind(stop=[\"\\nSQLResult:\"])\n",
|
||||
" | StrOutputParser()\n",
|
||||
" )\n"
|
||||
" RunnablePassthrough.assign(schema=get_schema)\n",
|
||||
" | prompt\n",
|
||||
" | model.bind(stop=[\"\\nSQLResult:\"])\n",
|
||||
" | StrOutputParser()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -125,7 +125,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sql_response.invoke({\"question\": \"How many employees are there?\"})\n"
|
||||
"sql_response.invoke({\"question\": \"How many employees are there?\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -141,7 +141,7 @@
|
||||
"Question: {question}\n",
|
||||
"SQL Query: {query}\n",
|
||||
"SQL Response: {response}\"\"\"\n",
|
||||
"prompt_response = ChatPromptTemplate.from_template(template)\n"
|
||||
"prompt_response = ChatPromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -152,14 +152,14 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"full_chain = (\n",
|
||||
" RunnablePassthrough.assign(query=sql_response) \n",
|
||||
" RunnablePassthrough.assign(query=sql_response)\n",
|
||||
" | RunnablePassthrough.assign(\n",
|
||||
" schema=get_schema,\n",
|
||||
" response=lambda x: db.run(x[\"query\"]),\n",
|
||||
" )\n",
|
||||
" | prompt_response \n",
|
||||
" | prompt_response\n",
|
||||
" | model\n",
|
||||
")\n"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -180,7 +180,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"full_chain.invoke({\"question\": \"How many employees are there?\"})\n"
|
||||
"full_chain.invoke({\"question\": \"How many employees are there?\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
Reference in New Issue
Block a user