notebook fmt (#12498)

Author: Bagatur
Date: 2023-10-29 15:50:09 -07:00
Committed by: GitHub
Parent: 56cc5b847c
Commit: 2424fff3f1
342 changed files with 8261 additions and 6796 deletions

View File

@@ -115,7 +115,9 @@
"agent = (\n",
" {\n",
" \"question\": lambda x: x[\"question\"],\n",
" \"intermediate_steps\": lambda x: convert_intermediate_steps(x[\"intermediate_steps\"])\n",
" \"intermediate_steps\": lambda x: convert_intermediate_steps(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" }\n",
" | prompt.partial(tools=convert_tools(tool_list))\n",
" | model.bind(stop=[\"</tool_input>\", \"</final_answer>\"])\n",

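A note on the pattern in this hunk: a plain dict of callables on the left of `|` is coerced into a runnable map, so every key is computed from the same input before the prompt sees it. A minimal runnable sketch of that shape, assuming the notebook's usual imports (the tiny prompt and stop token here are illustrative stand-ins for the XML agent's real ones):

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

# Illustrative stand-in for the notebook's agent prompt.
prompt = ChatPromptTemplate.from_template(
    "{question}\n\nSteps so far:\n{intermediate_steps}"
)

chain = (
    {
        "question": lambda x: x["question"],
        # Stand-in for convert_intermediate_steps: any callable works here.
        "intermediate_steps": lambda x: "\n".join(map(str, x["intermediate_steps"])),
    }
    | prompt
    | ChatOpenAI().bind(stop=["</final_answer>"])
)

# chain.invoke({"question": "What is 2 + 2?", "intermediate_steps": []})
```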
View File

@@ -18,7 +18,11 @@
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
-"from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
+"from langchain.prompts import (\n",
+" ChatPromptTemplate,\n",
+" SystemMessagePromptTemplate,\n",
+" HumanMessagePromptTemplate,\n",
+")\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain_experimental.utilities import PythonREPL"
]
@@ -37,9 +41,7 @@
"```python\n",
"....\n",
"```\"\"\"\n",
-"prompt = ChatPromptTemplate.from_messages(\n",
-" [(\"system\", template), (\"human\", \"{input}\")]\n",
-")\n",
+"prompt = ChatPromptTemplate.from_messages([(\"system\", template), (\"human\", \"{input}\")])\n",
"\n",
"model = ChatOpenAI()"
]

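For orientation, a sketch of the chain these imports feed, under a simplifying assumption (the model is asked for raw code and its output is piped straight into the REPL, rather than stripping markdown fences first as the notebook's template implies):

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_experimental.utilities import PythonREPL

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Write python code to solve the user's problem. Return only raw code."),
        ("human", "{input}"),
    ]
)

# A bare callable at the end of a pipe is coerced into a RunnableLambda;
# PythonREPL().run executes the generated source and returns its stdout.
chain = prompt | ChatOpenAI() | StrOutputParser() | PythonREPL().run
# chain.invoke({"input": "print(2 + 2)"})  # -> "4\n"
```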
View File

@@ -24,11 +24,13 @@
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"\n",
"model = ChatOpenAI()\n",
-"prompt = ChatPromptTemplate.from_messages([\n",
-" (\"system\", \"You are a helpful chatbot\"),\n",
-" MessagesPlaceholder(variable_name=\"history\"),\n",
-" (\"human\", \"{input}\")\n",
-"])\n"
+"prompt = ChatPromptTemplate.from_messages(\n",
+" [\n",
+" (\"system\", \"You are a helpful chatbot\"),\n",
+" MessagesPlaceholder(variable_name=\"history\"),\n",
+" (\"human\", \"{input}\"),\n",
+" ]\n",
+")"
]
},
{
@@ -38,7 +40,7 @@
"metadata": {},
"outputs": [],
"source": [
-"memory = ConversationBufferMemory(return_messages=True)\n"
+"memory = ConversationBufferMemory(return_messages=True)"
]
},
{
@@ -59,7 +61,7 @@
}
],
"source": [
-"memory.load_memory_variables({})\n"
+"memory.load_memory_variables({})"
]
},
{
@@ -69,9 +71,13 @@
"metadata": {},
"outputs": [],
"source": [
-"chain = RunnablePassthrough.assign(\n",
-" memory=RunnableLambda(memory.load_memory_variables) | itemgetter(\"history\")\n",
-") | prompt | model\n"
+"chain = (\n",
+" RunnablePassthrough.assign(\n",
+" memory=RunnableLambda(memory.load_memory_variables) | itemgetter(\"history\")\n",
+" )\n",
+" | prompt\n",
+" | model\n",
+")"
]
},
{
@@ -94,7 +100,7 @@
"source": [
"inputs = {\"input\": \"hi im bob\"}\n",
"response = chain.invoke(inputs)\n",
-"response\n"
+"response"
]
},
{
@@ -104,7 +110,7 @@
"metadata": {},
"outputs": [],
"source": [
-"memory.save_context(inputs, {\"output\": response.content})\n"
+"memory.save_context(inputs, {\"output\": response.content})"
]
},
{
@@ -126,7 +132,7 @@
}
],
"source": [
-"memory.load_memory_variables({})\n"
+"memory.load_memory_variables({})"
]
},
{
@@ -149,7 +155,7 @@
"source": [
"inputs = {\"input\": \"whats my name\"}\n",
"response = chain.invoke(inputs)\n",
-"response\n"
+"response"
]
}
],

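Taken together, these hunks implement a manual memory loop: load the buffer into the prompt, invoke, then write the turn back by hand. A compact sketch of the round trip; note the key assigned by `RunnablePassthrough.assign` must match the `MessagesPlaceholder` variable name, so this sketch assigns to `history`:

```python
from operator import itemgetter

from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

model = ChatOpenAI()
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful chatbot"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}"),
    ]
)
memory = ConversationBufferMemory(return_messages=True)

chain = (
    RunnablePassthrough.assign(
        history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
    )
    | prompt
    | model
)

inputs = {"input": "hi im bob"}
response = chain.invoke(inputs)
# The chain only reads memory; writing the turn back is a manual step.
memory.save_context(inputs, {"output": response.content})
```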
View File

@@ -40,9 +40,7 @@
"outputs": [],
"source": [
"model = OpenAI()\n",
-"prompt = ChatPromptTemplate.from_messages([\n",
-" (\"system\", \"repeat after me: {input}\")\n",
-"])"
+"prompt = ChatPromptTemplate.from_messages([(\"system\", \"repeat after me: {input}\")])"
]
},
{

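Context for this one-liner: the cell pipes a chat-style prompt into a completion-style `OpenAI()` model, and the prompt value is rendered to a plain string before the LLM sees it. Usage sketch:

```python
from langchain.llms import OpenAI
from langchain.prompts import ChatPromptTemplate

model = OpenAI()
prompt = ChatPromptTemplate.from_messages([("system", "repeat after me: {input}")])

# The chat prompt is stringified automatically for the completion model.
chain = prompt | model
# chain.invoke({"input": "hello"})
```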
View File

@@ -44,13 +44,20 @@
"from langchain.schema import StrOutputParser\n",
"\n",
"prompt1 = ChatPromptTemplate.from_template(\"what is the city {person} is from?\")\n",
"prompt2 = ChatPromptTemplate.from_template(\"what country is the city {city} in? respond in {language}\")\n",
"prompt2 = ChatPromptTemplate.from_template(\n",
" \"what country is the city {city} in? respond in {language}\"\n",
")\n",
"\n",
"model = ChatOpenAI()\n",
"\n",
"chain1 = prompt1 | model | StrOutputParser()\n",
"\n",
"chain2 = {\"city\": chain1, \"language\": itemgetter(\"language\")} | prompt2 | model | StrOutputParser()\n",
"chain2 = (\n",
" {\"city\": chain1, \"language\": itemgetter(\"language\")}\n",
" | prompt2\n",
" | model\n",
" | StrOutputParser()\n",
")\n",
"\n",
"chain2.invoke({\"person\": \"obama\", \"language\": \"spanish\"})"
]
@@ -64,17 +71,29 @@
"source": [
"from langchain.schema.runnable import RunnableMap, RunnablePassthrough\n",
"\n",
"prompt1 = ChatPromptTemplate.from_template(\"generate a {attribute} color. Return the name of the color and nothing else:\")\n",
"prompt2 = ChatPromptTemplate.from_template(\"what is a fruit of color: {color}. Return the name of the fruit and nothing else:\")\n",
"prompt3 = ChatPromptTemplate.from_template(\"what is a country with a flag that has the color: {color}. Return the name of the country and nothing else:\")\n",
"prompt4 = ChatPromptTemplate.from_template(\"What is the color of {fruit} and the flag of {country}?\")\n",
"prompt1 = ChatPromptTemplate.from_template(\n",
" \"generate a {attribute} color. Return the name of the color and nothing else:\"\n",
")\n",
"prompt2 = ChatPromptTemplate.from_template(\n",
" \"what is a fruit of color: {color}. Return the name of the fruit and nothing else:\"\n",
")\n",
"prompt3 = ChatPromptTemplate.from_template(\n",
" \"what is a country with a flag that has the color: {color}. Return the name of the country and nothing else:\"\n",
")\n",
"prompt4 = ChatPromptTemplate.from_template(\n",
" \"What is the color of {fruit} and the flag of {country}?\"\n",
")\n",
"\n",
"model_parser = model | StrOutputParser()\n",
"\n",
"color_generator = {\"attribute\": RunnablePassthrough()} | prompt1 | {\"color\": model_parser}\n",
"color_generator = (\n",
" {\"attribute\": RunnablePassthrough()} | prompt1 | {\"color\": model_parser}\n",
")\n",
"color_to_fruit = prompt2 | model_parser\n",
"color_to_country = prompt3 | model_parser\n",
-"question_generator = color_generator | {\"fruit\": color_to_fruit, \"country\": color_to_country} | prompt4"
+"question_generator = (\n",
+" color_generator | {\"fruit\": color_to_fruit, \"country\": color_to_country} | prompt4\n",
+")"
]
},
{
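The shape being reformatted here is a fan-out: `color_generator` emits a `{"color": ...}` dict, parallel chains consume it, and their outputs fill the next prompt. A condensed, self-contained sketch of the same pattern (one branch shown):

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

model_parser = ChatOpenAI() | StrOutputParser()

color_generator = (
    {"attribute": RunnablePassthrough()}
    | ChatPromptTemplate.from_template(
        "generate a {attribute} color. Return the name of the color and nothing else:"
    )
    | {"color": model_parser}
)
color_to_fruit = (
    ChatPromptTemplate.from_template(
        "what is a fruit of color: {color}. Return the name of the fruit and nothing else:"
    )
    | model_parser
)
# The dict step runs its branches in parallel on the same {"color": ...} input.
question_generator = (
    color_generator
    | {"fruit": color_to_fruit}
    | ChatPromptTemplate.from_template("What is the color of {fruit}?")
)
# question_generator.invoke("warm")  # -> a ChatPromptValue ready for a model
```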
@@ -148,9 +167,7 @@
"outputs": [],
"source": [
"planner = (\n",
-" ChatPromptTemplate.from_template(\n",
-" \"Generate an argument about: {input}\"\n",
-" )\n",
+" ChatPromptTemplate.from_template(\"Generate an argument about: {input}\")\n",
" | ChatOpenAI()\n",
" | StrOutputParser()\n",
" | {\"base_response\": RunnablePassthrough()}\n",
@@ -163,7 +180,7 @@
" | ChatOpenAI()\n",
" | StrOutputParser()\n",
")\n",
-"arguments_against = (\n",
+"arguments_against = (\n",
" ChatPromptTemplate.from_template(\n",
" \"List the cons or negative aspects of {base_response}\"\n",
" )\n",
@@ -184,7 +201,7 @@
")\n",
"\n",
"chain = (\n",
-" planner \n",
+" planner\n",
" | {\n",
" \"results_1\": arguments_for,\n",
" \"results_2\": arguments_against,\n",

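The last hunks of this file format a branch-and-merge chain: `planner` seeds `{"base_response": ...}`, two argument chains run in parallel against it, and a merge dict collects both sides. A trimmed sketch (the pros-side prompt and the `original_response` key are assumed by symmetry; the final responder prompt is elided in the diff):

```python
from operator import itemgetter

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

planner = (
    ChatPromptTemplate.from_template("Generate an argument about: {input}")
    | ChatOpenAI()
    | StrOutputParser()
    | {"base_response": RunnablePassthrough()}
)
arguments_for = (
    ChatPromptTemplate.from_template("List the pros or positive aspects of {base_response}")
    | ChatOpenAI()
    | StrOutputParser()
)
arguments_against = (
    ChatPromptTemplate.from_template("List the cons or negative aspects of {base_response}")
    | ChatOpenAI()
    | StrOutputParser()
)

# Both argument chains see the planner's output; itemgetter keeps the original.
chain = planner | {
    "results_1": arguments_for,
    "results_2": arguments_against,
    "original_response": itemgetter("base_response"),
}
# chain.invoke({"input": "scrum"})
```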
View File

@@ -47,7 +47,7 @@
"\n",
"prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")\n",
"model = ChatOpenAI()\n",
-"chain = prompt | model\n"
+"chain = prompt | model"
]
},
{
@@ -68,7 +68,7 @@
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})\n"
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
@@ -94,7 +94,7 @@
"metadata": {},
"outputs": [],
"source": [
"chain = prompt | model.bind(stop=[\"\\n\"])\n"
"chain = prompt | model.bind(stop=[\"\\n\"])"
]
},
{
@@ -115,7 +115,7 @@
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})\n"
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
@@ -135,25 +135,22 @@
"source": [
"functions = [\n",
" {\n",
" \"name\": \"joke\",\n",
" \"description\": \"A joke\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"setup\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The setup for the joke\"\n",
" },\n",
" \"punchline\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The punchline for the joke\"\n",
" }\n",
" \"name\": \"joke\",\n",
" \"description\": \"A joke\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"setup\": {\"type\": \"string\", \"description\": \"The setup for the joke\"},\n",
" \"punchline\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The punchline for the joke\",\n",
" },\n",
" },\n",
" \"required\": [\"setup\", \"punchline\"],\n",
" },\n",
" \"required\": [\"setup\", \"punchline\"]\n",
" }\n",
" }\n",
" ]\n",
"chain = prompt | model.bind(function_call= {\"name\": \"joke\"}, functions= functions)\n"
"]\n",
"chain = prompt | model.bind(function_call={\"name\": \"joke\"}, functions=functions)"
]
},
{
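For reference, the function-calling cells above compose as follows; the parser decodes the function-call arguments (a JSON string on the raw `AIMessage`) into a Python dict:

```python
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate

functions = [
    {
        "name": "joke",
        "description": "A joke",
        "parameters": {
            "type": "object",
            "properties": {
                "setup": {"type": "string", "description": "The setup for the joke"},
                "punchline": {"type": "string", "description": "The punchline for the joke"},
            },
            "required": ["setup", "punchline"],
        },
    }
]

prompt = ChatPromptTemplate.from_template("tell me a joke about {foo}")
chain = (
    prompt
    | ChatOpenAI().bind(function_call={"name": "joke"}, functions=functions)
    | JsonOutputFunctionsParser()
)
# chain.invoke({"foo": "bears"})  # -> {"setup": "...", "punchline": "..."}
```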
@@ -174,7 +171,7 @@
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"}, config={})\n"
"chain.invoke({\"foo\": \"bears\"}, config={})"
]
},
{
@@ -196,7 +193,7 @@
"source": [
"from langchain.schema.output_parser import StrOutputParser\n",
"\n",
-"chain = prompt | model | StrOutputParser()\n"
+"chain = prompt | model | StrOutputParser()"
]
},
{
@@ -225,7 +222,7 @@
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})\n"
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
@@ -248,10 +245,10 @@
"from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser\n",
"\n",
"chain = (\n",
-" prompt \n",
-" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
+" prompt\n",
+" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
" | JsonOutputFunctionsParser()\n",
-")\n"
+")"
]
},
{
@@ -273,7 +270,7 @@
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})\n"
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
@@ -286,10 +283,10 @@
"from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n",
"\n",
"chain = (\n",
-" prompt \n",
-" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
+" prompt\n",
+" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
-")\n"
+")"
]
},
{
@@ -310,7 +307,7 @@
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})\n"
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
@@ -334,11 +331,11 @@
"\n",
"map_ = RunnableMap(foo=RunnablePassthrough())\n",
"chain = (\n",
-" map_ \n",
+" map_\n",
" | prompt\n",
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
-")\n"
+")"
]
},
{
@@ -359,7 +356,7 @@
}
],
"source": [
"chain.invoke(\"bears\")\n"
"chain.invoke(\"bears\")"
]
},
{
@@ -378,11 +375,11 @@
"outputs": [],
"source": [
"chain = (\n",
" {\"foo\": RunnablePassthrough()} \n",
" {\"foo\": RunnablePassthrough()}\n",
" | prompt\n",
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
" | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
-")\n"
+")"
]
},
{
@@ -403,7 +400,7 @@
}
],
"source": [
"chain.invoke(\"bears\")\n"
"chain.invoke(\"bears\")"
]
}
],

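Why `chain.invoke("bears")` works in these last hunks: the leading map turns a bare string into the `{"foo": ...}` dict the prompt expects. An offline-runnable sketch, with a `RunnableLambda` standing in for the model:

```python
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

prompt = ChatPromptTemplate.from_template("tell me a joke about {foo}")

# RunnableLambda stands in for the ChatOpenAI pipeline so this runs offline.
chain = {"foo": RunnablePassthrough()} | prompt | RunnableLambda(lambda pv: pv.to_string())
print(chain.invoke("bears"))  # Human: tell me a joke about bears
```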
View File

@@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
-"!pip install langchain openai faiss-cpu tiktoken\n"
+"!pip install langchain openai faiss-cpu tiktoken"
]
},
{
@@ -43,7 +43,7 @@
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough, RunnableLambda\n",
-"from langchain.vectorstores import FAISS\n"
+"from langchain.vectorstores import FAISS"
]
},
{
@@ -53,7 +53,9 @@
"metadata": {},
"outputs": [],
"source": [
"vectorstore = FAISS.from_texts([\"harrison worked at kensho\"], embedding=OpenAIEmbeddings())\n",
"vectorstore = FAISS.from_texts(\n",
" [\"harrison worked at kensho\"], embedding=OpenAIEmbeddings()\n",
")\n",
"retriever = vectorstore.as_retriever()\n",
"\n",
"template = \"\"\"Answer the question based only on the following context:\n",
@@ -63,7 +65,7 @@
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
-"model = ChatOpenAI()\n"
+"model = ChatOpenAI()"
]
},
{
@@ -74,11 +76,11 @@
"outputs": [],
"source": [
"chain = (\n",
" {\"context\": retriever, \"question\": RunnablePassthrough()} \n",
" | prompt \n",
" | model \n",
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
" | prompt\n",
" | model\n",
" | StrOutputParser()\n",
-")\n"
+")"
]
},
{
@@ -99,7 +101,7 @@
}
],
"source": [
"chain.invoke(\"where did harrison work?\")\n"
"chain.invoke(\"where did harrison work?\")"
]
},
{
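Assembled, the cells formatted so far give the notebook's minimal retrieval chain (requires the packages from the pip cell plus an OpenAI key):

```python
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import FAISS

vectorstore = FAISS.from_texts(
    ["harrison worked at kensho"], embedding=OpenAIEmbeddings()
)
retriever = vectorstore.as_retriever()

template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# The retriever receives the raw question string and returns documents,
# which fill the prompt's {context} slot.
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | ChatOpenAI()
    | StrOutputParser()
)
# chain.invoke("where did harrison work?")
```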
@@ -118,11 +120,16 @@
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
-"chain = {\n",
-" \"context\": itemgetter(\"question\") | retriever, \n",
-" \"question\": itemgetter(\"question\"), \n",
-" \"language\": itemgetter(\"language\")\n",
-"} | prompt | model | StrOutputParser()\n"
+"chain = (\n",
+" {\n",
+" \"context\": itemgetter(\"question\") | retriever,\n",
+" \"question\": itemgetter(\"question\"),\n",
+" \"language\": itemgetter(\"language\"),\n",
+" }\n",
+" | prompt\n",
+" | model\n",
+" | StrOutputParser()\n",
+")"
]
},
{
@@ -143,7 +150,7 @@
}
],
"source": [
"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})\n"
"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})"
]
},
{
@@ -164,7 +171,7 @@
"outputs": [],
"source": [
"from langchain.schema.runnable import RunnableMap\n",
-"from langchain.schema import format_document\n"
+"from langchain.schema import format_document"
]
},
{
@@ -182,7 +189,7 @@
"{chat_history}\n",
"Follow Up Input: {question}\n",
"Standalone question:\"\"\"\n",
-"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)\n"
+"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)"
]
},
{
@@ -197,7 +204,7 @@
"\n",
"Question: {question}\n",
"\"\"\"\n",
-"ANSWER_PROMPT = ChatPromptTemplate.from_template(template)\n"
+"ANSWER_PROMPT = ChatPromptTemplate.from_template(template)"
]
},
{
@@ -208,9 +215,13 @@
"outputs": [],
"source": [
"DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=\"{page_content}\")\n",
-"def _combine_documents(docs, document_prompt = DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"):\n",
+"\n",
+"\n",
+"def _combine_documents(\n",
+" docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"\n",
+"):\n",
" doc_strings = [format_document(doc, document_prompt) for doc in docs]\n",
-" return document_separator.join(doc_strings)\n"
+" return document_separator.join(doc_strings)"
]
},
{
@@ -221,13 +232,15 @@
"outputs": [],
"source": [
"from typing import Tuple, List\n",
+"\n",
+"\n",
"def _format_chat_history(chat_history: List[Tuple]) -> str:\n",
" buffer = \"\"\n",
" for dialogue_turn in chat_history:\n",
" human = \"Human: \" + dialogue_turn[0]\n",
" ai = \"Assistant: \" + dialogue_turn[1]\n",
" buffer += \"\\n\" + \"\\n\".join([human, ai])\n",
-" return buffer\n"
+" return buffer"
]
},
{
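`_format_chat_history` is pure Python, so its output is easy to check in isolation:

```python
from typing import List, Tuple


def _format_chat_history(chat_history: List[Tuple]) -> str:
    buffer = ""
    for dialogue_turn in chat_history:
        human = "Human: " + dialogue_turn[0]
        ai = "Assistant: " + dialogue_turn[1]
        buffer += "\n" + "\n".join([human, ai])
    return buffer


print(_format_chat_history([("Who wrote this notebook?", "Harrison")]))
# Human: Who wrote this notebook?
# Assistant: Harrison
```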
@@ -239,14 +252,17 @@
"source": [
"_inputs = RunnableMap(\n",
" standalone_question=RunnablePassthrough.assign(\n",
-" chat_history=lambda x: _format_chat_history(x['chat_history'])\n",
-" ) | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
+" chat_history=lambda x: _format_chat_history(x[\"chat_history\"])\n",
+" )\n",
+" | CONDENSE_QUESTION_PROMPT\n",
+" | ChatOpenAI(temperature=0)\n",
+" | StrOutputParser(),\n",
")\n",
"_context = {\n",
" \"context\": itemgetter(\"standalone_question\") | retriever | _combine_documents,\n",
" \"question\": lambda x: x[\"standalone_question\"]\n",
" \"question\": lambda x: x[\"standalone_question\"],\n",
"}\n",
-"conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()\n"
+"conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()"
]
},
{
@@ -267,10 +283,12 @@
}
],
"source": [
-"conversational_qa_chain.invoke({\n",
-" \"question\": \"where did harrison work?\",\n",
-" \"chat_history\": [],\n",
-"})\n"
+"conversational_qa_chain.invoke(\n",
+" {\n",
+" \"question\": \"where did harrison work?\",\n",
+" \"chat_history\": [],\n",
+" }\n",
+")"
]
},
{
@@ -291,10 +309,12 @@
}
],
"source": [
-"conversational_qa_chain.invoke({\n",
-" \"question\": \"where did he work?\",\n",
-" \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n",
-"})\n"
+"conversational_qa_chain.invoke(\n",
+" {\n",
+" \"question\": \"where did he work?\",\n",
+" \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n",
+" }\n",
+")"
]
},
{
@@ -315,7 +335,7 @@
"outputs": [],
"source": [
"from operator import itemgetter\n",
-"from langchain.memory import ConversationBufferMemory\n"
+"from langchain.memory import ConversationBufferMemory"
]
},
{
@@ -325,7 +345,9 @@
"metadata": {},
"outputs": [],
"source": [
"memory = ConversationBufferMemory(return_messages=True, output_key=\"answer\", input_key=\"question\")\n"
"memory = ConversationBufferMemory(\n",
" return_messages=True, output_key=\"answer\", input_key=\"question\"\n",
")"
]
},
{
@@ -344,18 +366,21 @@
"standalone_question = {\n",
" \"standalone_question\": {\n",
" \"question\": lambda x: x[\"question\"],\n",
" \"chat_history\": lambda x: _format_chat_history(x['chat_history'])\n",
" } | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
" \"chat_history\": lambda x: _format_chat_history(x[\"chat_history\"]),\n",
" }\n",
" | CONDENSE_QUESTION_PROMPT\n",
" | ChatOpenAI(temperature=0)\n",
" | StrOutputParser(),\n",
"}\n",
"# Now we retrieve the documents\n",
"retrieved_documents = {\n",
" \"docs\": itemgetter(\"standalone_question\") | retriever,\n",
" \"question\": lambda x: x[\"standalone_question\"]\n",
" \"question\": lambda x: x[\"standalone_question\"],\n",
"}\n",
"# Now we construct the inputs for the final prompt\n",
"final_inputs = {\n",
" \"context\": lambda x: _combine_documents(x[\"docs\"]),\n",
" \"question\": itemgetter(\"question\")\n",
" \"question\": itemgetter(\"question\"),\n",
"}\n",
"# And finally, we do the part that returns the answers\n",
"answer = {\n",
@@ -363,7 +388,7 @@
" \"docs\": itemgetter(\"docs\"),\n",
"}\n",
"# And now we put it all together!\n",
-"final_chain = loaded_memory | standalone_question | retrieved_documents | answer\n"
+"final_chain = loaded_memory | standalone_question | retrieved_documents | answer"
]
},
{
@@ -387,7 +412,7 @@
"source": [
"inputs = {\"question\": \"where did harrison work?\"}\n",
"result = final_chain.invoke(inputs)\n",
-"result\n"
+"result"
]
},
{
@@ -400,7 +425,7 @@
"# Note that the memory does not save automatically\n",
"# This will be improved in the future\n",
"# For now you need to save it yourself\n",
"memory.save_context(inputs, {\"answer\": result[\"answer\"].content})\n"
"memory.save_context(inputs, {\"answer\": result[\"answer\"].content})"
]
},
{
@@ -422,7 +447,7 @@
}
],
"source": [
-"memory.load_memory_variables({})\n"
+"memory.load_memory_variables({})"
]
}
],

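The `loaded_memory` step that `final_chain` starts with is not shown in this diff. A plausible definition, assumed from the surrounding cells (it reads the buffer and exposes `chat_history` to the rest of the pipeline), plus the manual save step:

```python
from operator import itemgetter

from langchain.memory import ConversationBufferMemory
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

memory = ConversationBufferMemory(
    return_messages=True, output_key="answer", input_key="question"
)

# Assumed definition: load the buffer and expose it as "chat_history".
loaded_memory = RunnablePassthrough.assign(
    chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
)

# final_chain = loaded_memory | standalone_question | retrieved_documents | answer
# result = final_chain.invoke({"question": "where did harrison work?"})
# memory.save_context({"question": "..."}, {"answer": result["answer"].content})
```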
View File

@@ -33,7 +33,7 @@
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
-"prompt = ChatPromptTemplate.from_template(template)\n"
+"prompt = ChatPromptTemplate.from_template(template)"
]
},
{
@@ -43,7 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
-"from langchain.utilities import SQLDatabase\n"
+"from langchain.utilities import SQLDatabase"
]
},
{
@@ -61,7 +61,7 @@
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")\n"
"db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")"
]
},
{
@@ -72,7 +72,7 @@
"outputs": [],
"source": [
"def get_schema(_):\n",
-" return db.get_table_info()\n"
+" return db.get_table_info()"
]
},
{
@@ -83,7 +83,7 @@
"outputs": [],
"source": [
"def run_query(query):\n",
-" return db.run(query)\n"
+" return db.run(query)"
]
},
{
@@ -100,11 +100,11 @@
"model = ChatOpenAI()\n",
"\n",
"sql_response = (\n",
-" RunnablePassthrough.assign(schema=get_schema)\n",
-" | prompt\n",
-" | model.bind(stop=[\"\\nSQLResult:\"])\n",
-" | StrOutputParser()\n",
-" )\n"
+" RunnablePassthrough.assign(schema=get_schema)\n",
+" | prompt\n",
+" | model.bind(stop=[\"\\nSQLResult:\"])\n",
+" | StrOutputParser()\n",
+")"
]
},
{
@@ -125,7 +125,7 @@
}
],
"source": [
"sql_response.invoke({\"question\": \"How many employees are there?\"})\n"
"sql_response.invoke({\"question\": \"How many employees are there?\"})"
]
},
{
@@ -141,7 +141,7 @@
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
-"prompt_response = ChatPromptTemplate.from_template(template)\n"
+"prompt_response = ChatPromptTemplate.from_template(template)"
]
},
{
@@ -152,14 +152,14 @@
"outputs": [],
"source": [
"full_chain = (\n",
-" RunnablePassthrough.assign(query=sql_response) \n",
+" RunnablePassthrough.assign(query=sql_response)\n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n",
" )\n",
-" | prompt_response \n",
+" | prompt_response\n",
" | model\n",
-")\n"
+")"
]
},
{
@@ -180,7 +180,7 @@
}
],
"source": [
"full_chain.invoke({\"question\": \"How many employees are there?\"})\n"
"full_chain.invoke({\"question\": \"How many employees are there?\"})"
]
},
{
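End to end, this file's cells form a two-stage text-to-SQL chain: stage one writes a query from the schema, stage two executes it and verbalizes the result. A condensed runnable sketch (prompt preambles are shortened here, since the originals are truncated in the diff):

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///./Chinook.db")
model = ChatOpenAI()

sql_prompt = ChatPromptTemplate.from_template(
    "Given the schema:\n{schema}\n\nQuestion: {question}\nSQL Query:"
)
# Stage 1: question -> SQL. The stop token keeps the model from also
# inventing a "SQLResult:" section after the query.
sql_response = (
    RunnablePassthrough.assign(schema=lambda _: db.get_table_info())
    | sql_prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

answer_prompt = ChatPromptTemplate.from_template(
    "Given the schema:\n{schema}\n\nQuestion: {question}\n"
    "SQL Query: {query}\nSQL Response: {response}\n\nAnswer in natural language:"
)
# Stage 2: execute the generated query and verbalize the result.
full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=lambda _: db.get_table_info(),
        response=lambda x: db.run(x["query"]),
    )
    | answer_prompt
    | model
)
# full_chain.invoke({"question": "How many employees are there?"})
```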