Mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-24 16:37:46 +00:00)
Rename RunnableMap to RunnableParallel (#11487)
- keep alias for RunnableMap
- update docs to use RunnableParallel and RunnablePassthrough.assign
This commit is contained in:
parent
6a10e8ef31
commit
628cc4cce8
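
Before the file-by-file diff, a minimal sketch of what the rename means for user code. The names and behaviour below come from this PR itself (the `RunnableMap = RunnableParallel` alias and the new keyword-argument constructor appear in the `base.py` hunks); the model setup is only illustrative:

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableMap, RunnableParallel

model = ChatOpenAI()
joke_chain = ChatPromptTemplate.from_template("tell me a joke about {topic}") | model
poem_chain = ChatPromptTemplate.from_template("write a 2-line poem about {topic}") | model

# New preferred name, using the kwargs constructor this PR adds:
map_chain = RunnableParallel(joke=joke_chain, poem=poem_chain)

# The old name and the dict-style constructor keep working, since the PR
# keeps `RunnableMap = RunnableParallel` as an alias:
legacy_chain = RunnableMap({"joke": joke_chain, "poem": poem_chain})

# Both run their branches in parallel and return a dict of outputs:
map_chain.invoke({"topic": "bear"})
```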
```diff
@@ -17,9 +17,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from operator import itemgetter\n",
 "from langchain.chat_models import ChatOpenAI\n",
 "from langchain.memory import ConversationBufferMemory\n",
-"from langchain.schema.runnable import RunnableMap\n",
+"from langchain.schema.runnable import RunnablePassthrough\n",
 "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
 "\n",
 "model = ChatOpenAI()\n",
@@ -27,7 +28,7 @@
 " (\"system\", \"You are a helpful chatbot\"),\n",
 " MessagesPlaceholder(variable_name=\"history\"),\n",
 " (\"human\", \"{input}\")\n",
-"])"
+"])\n"
 ]
 },
 {
@@ -37,7 +38,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"memory = ConversationBufferMemory(return_messages=True)"
+"memory = ConversationBufferMemory(return_messages=True)\n"
 ]
 },
 {
@@ -58,7 +59,7 @@
 }
 ],
 "source": [
-"memory.load_memory_variables({})"
+"memory.load_memory_variables({})\n"
 ]
 },
 {
@@ -68,13 +69,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"chain = RunnableMap({\n",
-" \"input\": lambda x: x[\"input\"],\n",
-" \"memory\": memory.load_memory_variables\n",
-"}) | {\n",
-" \"input\": lambda x: x[\"input\"],\n",
-" \"history\": lambda x: x[\"memory\"][\"history\"]\n",
-"} | prompt | model"
+"chain = RunnablePassthrough.assign(\n",
+" memory=memory.load_memory_variables | itemgetter(\"history\")\n",
+") | prompt | model\n"
 ]
 },
 {
@@ -97,7 +94,7 @@
 "source": [
 "inputs = {\"input\": \"hi im bob\"}\n",
 "response = chain.invoke(inputs)\n",
-"response"
+"response\n"
 ]
 },
 {
@@ -107,7 +104,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"memory.save_context(inputs, {\"output\": response.content})"
+"memory.save_context(inputs, {\"output\": response.content})\n"
 ]
 },
 {
@@ -129,7 +126,7 @@
 }
 ],
 "source": [
-"memory.load_memory_variables({})"
+"memory.load_memory_variables({})\n"
 ]
 },
 {
@@ -152,7 +149,7 @@
 "source": [
 "inputs = {\"input\": \"whats my name\"}\n",
 "response = chain.invoke(inputs)\n",
-"response"
+"response\n"
 ]
 },
 {
```
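Reviewer note on the hunks above: `RunnablePassthrough.assign` replaces the two-step map-then-reshape pattern because it merges its computed keys into the original input instead of replacing it. A small sketch of that behaviour (the `load` function is a hypothetical stand-in for `memory.load_memory_variables`):

```python
from langchain.schema.runnable import RunnablePassthrough

def load(_: dict) -> dict:
    # hypothetical stand-in for memory.load_memory_variables
    return {"history": ["previous messages"]}

# .assign passes the input dict through and adds the new key alongside it:
step = RunnablePassthrough.assign(memory=load)
step.invoke({"input": "hi"})
# -> {'input': 'hi', 'memory': {'history': ['previous messages']}}
```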
```diff
@@ -8,7 +8,7 @@
 "---\n",
 "sidebar_position: 0\n",
 "title: Prompt + LLM\n",
-"---"
+"---\n"
 ]
 },
 {
@@ -47,7 +47,7 @@
 "\n",
 "prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")\n",
 "model = ChatOpenAI()\n",
-"chain = prompt | model"
+"chain = prompt | model\n"
 ]
 },
 {
@@ -68,7 +68,7 @@
 }
 ],
 "source": [
-"chain.invoke({\"foo\": \"bears\"})"
+"chain.invoke({\"foo\": \"bears\"})\n"
 ]
 },
 {
@@ -94,7 +94,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"chain = prompt | model.bind(stop=[\"\\n\"])"
+"chain = prompt | model.bind(stop=[\"\\n\"])\n"
 ]
 },
 {
@@ -115,7 +115,7 @@
 }
 ],
 "source": [
-"chain.invoke({\"foo\": \"bears\"})"
+"chain.invoke({\"foo\": \"bears\"})\n"
 ]
 },
 {
@@ -153,7 +153,7 @@
 " }\n",
 " }\n",
 " ]\n",
-"chain = prompt | model.bind(function_call= {\"name\": \"joke\"}, functions= functions)"
+"chain = prompt | model.bind(function_call= {\"name\": \"joke\"}, functions= functions)\n"
 ]
 },
 {
@@ -174,7 +174,7 @@
 }
 ],
 "source": [
-"chain.invoke({\"foo\": \"bears\"}, config={})"
+"chain.invoke({\"foo\": \"bears\"}, config={})\n"
 ]
 },
 {
@@ -196,7 +196,7 @@
 "source": [
 "from langchain.schema.output_parser import StrOutputParser\n",
 "\n",
-"chain = prompt | model | StrOutputParser()"
+"chain = prompt | model | StrOutputParser()\n"
 ]
 },
 {
@@ -225,7 +225,7 @@
 }
 ],
 "source": [
-"chain.invoke({\"foo\": \"bears\"})"
+"chain.invoke({\"foo\": \"bears\"})\n"
 ]
 },
 {
@@ -251,7 +251,7 @@
 " prompt \n",
 " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
 " | JsonOutputFunctionsParser()\n",
-")"
+")\n"
 ]
 },
 {
@@ -273,7 +273,7 @@
 }
 ],
 "source": [
-"chain.invoke({\"foo\": \"bears\"})"
+"chain.invoke({\"foo\": \"bears\"})\n"
 ]
 },
 {
@@ -289,7 +289,7 @@
 " prompt \n",
 " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
 " | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
-")"
+")\n"
 ]
 },
 {
@@ -310,7 +310,7 @@
 }
 ],
 "source": [
-"chain.invoke({\"foo\": \"bears\"})"
+"chain.invoke({\"foo\": \"bears\"})\n"
 ]
 },
 {
@@ -332,13 +332,13 @@
 "source": [
 "from langchain.schema.runnable import RunnableMap, RunnablePassthrough\n",
 "\n",
-"map_ = RunnableMap({\"foo\": RunnablePassthrough()})\n",
+"map_ = RunnableMap(foo=RunnablePassthrough())\n",
 "chain = (\n",
 " map_ \n",
 " | prompt\n",
 " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
 " | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
-")"
+")\n"
 ]
 },
 {
@@ -359,7 +359,7 @@
 }
 ],
 "source": [
-"chain.invoke(\"bears\")"
+"chain.invoke(\"bears\")\n"
 ]
 },
 {
@@ -382,7 +382,7 @@
 " | prompt\n",
 " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
 " | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
-")"
+")\n"
 ]
 },
 {
@@ -403,7 +403,7 @@
 }
 ],
 "source": [
-"chain.invoke(\"bears\")"
+"chain.invoke(\"bears\")\n"
 ]
 }
 ],
```
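A side note on the `map_` hunk above: since `coerce_to_runnable` (changed near the end of this diff) now turns a plain dict into a `RunnableParallel`, the explicit constructor is optional when the map sits inside a sequence. A sketch, reusing the same `prompt`/`model` setup as the notebook:

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough

prompt = ChatPromptTemplate.from_template("tell me a joke about {foo}")
model = ChatOpenAI()

# Explicit construction with the new kwargs form:
chain_a = RunnableParallel(foo=RunnablePassthrough()) | prompt | model

# Equivalent: a bare dict is coerced to a RunnableParallel when piped:
chain_b = {"foo": RunnablePassthrough()} | prompt | model

chain_a.invoke("bears")  # both accept the raw topic string as input
```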
```diff
@@ -8,7 +8,7 @@
 "---\n",
 "sidebar_position: 1\n",
 "title: RAG\n",
-"---"
+"---\n"
 ]
 },
 {
@@ -26,7 +26,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"!pip install langchain openai faiss-cpu tiktoken"
+"!pip install langchain openai faiss-cpu tiktoken\n"
 ]
 },
 {
@@ -43,7 +43,7 @@
 "from langchain.embeddings import OpenAIEmbeddings\n",
 "from langchain.schema.output_parser import StrOutputParser\n",
 "from langchain.schema.runnable import RunnablePassthrough\n",
-"from langchain.vectorstores import FAISS"
+"from langchain.vectorstores import FAISS\n"
 ]
 },
 {
@@ -63,7 +63,7 @@
 "\"\"\"\n",
 "prompt = ChatPromptTemplate.from_template(template)\n",
 "\n",
-"model = ChatOpenAI()"
+"model = ChatOpenAI()\n"
 ]
 },
 {
@@ -78,7 +78,7 @@
 " | prompt \n",
 " | model \n",
 " | StrOutputParser()\n",
-")"
+")\n"
 ]
 },
 {
@@ -99,7 +99,7 @@
 }
 ],
 "source": [
-"chain.invoke(\"where did harrison work?\")"
+"chain.invoke(\"where did harrison work?\")\n"
 ]
 },
 {
@@ -122,7 +122,7 @@
 " \"context\": itemgetter(\"question\") | retriever, \n",
 " \"question\": itemgetter(\"question\"), \n",
 " \"language\": itemgetter(\"language\")\n",
-"} | prompt | model | StrOutputParser()"
+"} | prompt | model | StrOutputParser()\n"
 ]
 },
 {
@@ -143,7 +143,7 @@
 }
 ],
 "source": [
-"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})"
+"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})\n"
 ]
 },
 {
@@ -164,7 +164,7 @@
 "outputs": [],
 "source": [
 "from langchain.schema.runnable import RunnableMap\n",
-"from langchain.schema import format_document"
+"from langchain.schema import format_document\n"
 ]
 },
 {
@@ -182,7 +182,7 @@
 "{chat_history}\n",
 "Follow Up Input: {question}\n",
 "Standalone question:\"\"\"\n",
-"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)"
+"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)\n"
 ]
 },
 {
@@ -197,7 +197,7 @@
 "\n",
 "Question: {question}\n",
 "\"\"\"\n",
-"ANSWER_PROMPT = ChatPromptTemplate.from_template(template)"
+"ANSWER_PROMPT = ChatPromptTemplate.from_template(template)\n"
 ]
 },
 {
@@ -210,7 +210,7 @@
 "DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=\"{page_content}\")\n",
 "def _combine_documents(docs, document_prompt = DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"):\n",
 " doc_strings = [format_document(doc, document_prompt) for doc in docs]\n",
-" return document_separator.join(doc_strings)"
+" return document_separator.join(doc_strings)\n"
 ]
 },
 {
@@ -227,7 +227,7 @@
 " human = \"Human: \" + dialogue_turn[0]\n",
 " ai = \"Assistant: \" + dialogue_turn[1]\n",
 " buffer += \"\\n\" + \"\\n\".join([human, ai])\n",
-" return buffer"
+" return buffer\n"
 ]
 },
 {
@@ -238,18 +238,15 @@
 "outputs": [],
 "source": [
 "_inputs = RunnableMap(\n",
-" {\n",
-" \"standalone_question\": {\n",
-" \"question\": lambda x: x[\"question\"],\n",
-" \"chat_history\": lambda x: _format_chat_history(x['chat_history'])\n",
-" } | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
-" }\n",
+" standalone_question=RunnablePassthrough.assign(\n",
+" chat_history=lambda x: _format_chat_history(x['chat_history'])\n",
+" ) | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
 ")\n",
 "_context = {\n",
 " \"context\": itemgetter(\"standalone_question\") | retriever | _combine_documents,\n",
 " \"question\": lambda x: x[\"standalone_question\"]\n",
 "}\n",
-"conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()"
+"conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()\n"
 ]
 },
 {
@@ -273,7 +270,7 @@
 "conversational_qa_chain.invoke({\n",
 " \"question\": \"where did harrison work?\",\n",
 " \"chat_history\": [],\n",
-"})"
+"})\n"
 ]
 },
 {
@@ -297,7 +294,7 @@
 "conversational_qa_chain.invoke({\n",
 " \"question\": \"where did he work?\",\n",
 " \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n",
-"})"
+"})\n"
 ]
 },
 {
@@ -317,7 +314,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain.memory import ConversationBufferMemory"
+"from operator import itemgetter\n",
+"from langchain.memory import ConversationBufferMemory\n"
 ]
 },
 {
@@ -327,7 +325,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"memory = ConversationBufferMemory(return_messages=True, output_key=\"answer\", input_key=\"question\")"
+"memory = ConversationBufferMemory(return_messages=True, output_key=\"answer\", input_key=\"question\")\n"
 ]
 },
 {
@@ -338,19 +336,10 @@
 "outputs": [],
 "source": [
 "# First we add a step to load memory\n",
-"# This needs to be a RunnableMap because its the first input\n",
-"loaded_memory = RunnableMap(\n",
-" {\n",
-" \"question\": itemgetter(\"question\"),\n",
-" \"memory\": memory.load_memory_variables,\n",
-" }\n",
+"# This adds a \"memory\" key to the input object\n",
+"loaded_memory = RunnablePassthrough.assign(\n",
+" chat_history=memory.load_memory_variables | itemgetter(\"history\"),\n",
 ")\n",
-"# Next we add a step to expand memory into the variables\n",
-"expanded_memory = {\n",
-" \"question\": itemgetter(\"question\"),\n",
-" \"chat_history\": lambda x: x[\"memory\"][\"history\"]\n",
-"}\n",
-"\n",
 "# Now we calculate the standalone question\n",
 "standalone_question = {\n",
 " \"standalone_question\": {\n",
@@ -374,7 +363,7 @@
 " \"docs\": itemgetter(\"docs\"),\n",
 "}\n",
 "# And now we put it all together!\n",
-"final_chain = loaded_memory | expanded_memory | standalone_question | retrieved_documents | answer"
+"final_chain = loaded_memory | expanded_memory | standalone_question | retrieved_documents | answer\n"
 ]
 },
 {
@@ -398,7 +387,7 @@
 "source": [
 "inputs = {\"question\": \"where did harrison work?\"}\n",
 "result = final_chain.invoke(inputs)\n",
-"result"
+"result\n"
 ]
 },
 {
@@ -411,7 +400,7 @@
 "# Note that the memory does not save automatically\n",
 "# This will be improved in the future\n",
 "# For now you need to save it yourself\n",
-"memory.save_context(inputs, {\"answer\": result[\"answer\"].content})"
+"memory.save_context(inputs, {\"answer\": result[\"answer\"].content})\n"
 ]
 },
 {
@@ -433,7 +422,7 @@
 }
 ],
 "source": [
-"memory.load_memory_variables({})"
+"memory.load_memory_variables({})\n"
 ]
 }
 ],
```
```diff
@@ -8,7 +8,7 @@
 "---\n",
 "sidebar_position: 3\n",
 "title: Querying a SQL DB\n",
-"---"
+"---\n"
 ]
 },
 {
@@ -33,7 +33,7 @@
 "\n",
 "Question: {question}\n",
 "SQL Query:\"\"\"\n",
-"prompt = ChatPromptTemplate.from_template(template)"
+"prompt = ChatPromptTemplate.from_template(template)\n"
 ]
 },
 {
@@ -43,7 +43,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain.utilities import SQLDatabase"
+"from langchain.utilities import SQLDatabase\n"
 ]
 },
 {
@@ -61,7 +61,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")"
+"db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")\n"
 ]
 },
 {
@@ -72,7 +72,7 @@
 "outputs": [],
 "source": [
 "def get_schema(_):\n",
-" return db.get_table_info()"
+" return db.get_table_info()\n"
 ]
 },
 {
@@ -83,7 +83,7 @@
 "outputs": [],
 "source": [
 "def run_query(query):\n",
-" return db.run(query)"
+" return db.run(query)\n"
 ]
 },
 {
@@ -93,24 +93,18 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from operator import itemgetter\n",
-"\n",
 "from langchain.chat_models import ChatOpenAI\n",
 "from langchain.schema.output_parser import StrOutputParser\n",
-"from langchain.schema.runnable import RunnableLambda, RunnableMap\n",
+"from langchain.schema.runnable import RunnablePassthrough\n",
 "\n",
 "model = ChatOpenAI()\n",
 "\n",
-"inputs = {\n",
-" \"schema\": RunnableLambda(get_schema),\n",
-" \"question\": itemgetter(\"question\")\n",
-"}\n",
 "sql_response = (\n",
-" RunnableMap(inputs)\n",
+" RunnablePassthrough.assign(schema=get_schema)\n",
 " | prompt\n",
 " | model.bind(stop=[\"\\nSQLResult:\"])\n",
 " | StrOutputParser()\n",
-" )"
+" )\n"
 ]
 },
 {
@@ -131,7 +125,7 @@
 }
 ],
 "source": [
-"sql_response.invoke({\"question\": \"How many employees are there?\"})"
+"sql_response.invoke({\"question\": \"How many employees are there?\"})\n"
 ]
 },
 {
@@ -147,7 +141,7 @@
 "Question: {question}\n",
 "SQL Query: {query}\n",
 "SQL Response: {response}\"\"\"\n",
-"prompt_response = ChatPromptTemplate.from_template(template)"
+"prompt_response = ChatPromptTemplate.from_template(template)\n"
 ]
 },
 {
@@ -158,19 +152,14 @@
 "outputs": [],
 "source": [
 "full_chain = (\n",
-" RunnableMap({\n",
-" \"question\": itemgetter(\"question\"),\n",
-" \"query\": sql_response,\n",
-" }) \n",
-" | {\n",
-" \"schema\": RunnableLambda(get_schema),\n",
-" \"question\": itemgetter(\"question\"),\n",
-" \"query\": itemgetter(\"query\"),\n",
-" \"response\": lambda x: db.run(x[\"query\"]) \n",
-" } \n",
+" RunnablePassthrough.assign(query=sql_response) \n",
+" | RunnablePassthrough.assign(\n",
+" schema=get_schema,\n",
+" response=lambda x: db.run(x[\"query\"]),\n",
+" )\n",
 " | prompt_response \n",
 " | model\n",
-")"
+")\n"
 ]
 },
 {
@@ -191,7 +180,7 @@
 }
 ],
 "source": [
-"full_chain.invoke({\"question\": \"How many employees are there?\"})"
+"full_chain.invoke({\"question\": \"How many employees are there?\"})\n"
 ]
 },
 {
```
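The rewritten `full_chain` above relies on chaining two `RunnablePassthrough.assign` steps, each adding keys on top of everything accumulated so far. A schematic sketch with stand-in functions rather than the notebook's real `sql_response` and `db`:

```python
from langchain.schema.runnable import RunnablePassthrough

def make_query(inputs: dict) -> str:
    return "SELECT COUNT(*) FROM Employee"  # stand-in for sql_response

def run_query(inputs: dict) -> str:
    return "8"  # stand-in for db.run(inputs["query"])

chain = (
    RunnablePassthrough.assign(query=make_query)      # adds "query"
    | RunnablePassthrough.assign(response=run_query)  # sees "query", adds "response"
)
chain.invoke({"question": "How many employees are there?"})
# -> {'question': ..., 'query': 'SELECT COUNT(*) FROM Employee', 'response': '8'}
```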
```diff
@@ -5,9 +5,9 @@
 "id": "b022ab74-794d-4c54-ad47-ff9549ddb9d2",
 "metadata": {},
 "source": [
-"# Use RunnableMaps\n",
+"# Use RunnableParallel/RunnableMap\n",
 "\n",
-"RunnableMaps make it easy to execute multiple Runnables in parallel, and to return the output of these Runnables as a map."
+"RunnableParallel (aka. RunnableMap) makes it easy to execute multiple Runnables in parallel, and to return the output of these Runnables as a map."
 ]
 },
 {
@@ -31,16 +31,16 @@
 "source": [
 "from langchain.chat_models import ChatOpenAI\n",
 "from langchain.prompts import ChatPromptTemplate\n",
-"from langchain.schema.runnable import RunnableMap\n",
+"from langchain.schema.runnable import RunnableParallel\n",
 "\n",
 "\n",
 "model = ChatOpenAI()\n",
 "joke_chain = ChatPromptTemplate.from_template(\"tell me a joke about {topic}\") | model\n",
 "poem_chain = ChatPromptTemplate.from_template(\"write a 2-line poem about {topic}\") | model\n",
 "\n",
-"map_chain = RunnableMap({\"joke\": joke_chain, \"poem\": poem_chain,})\n",
+"map_chain = RunnableParallel(joke=joke_chain, poem=poem_chain)\n",
 "\n",
-"map_chain.invoke({\"topic\": \"bear\"})"
+"map_chain.invoke({\"topic\": \"bear\"})\n"
 ]
 },
 {
@@ -91,7 +91,7 @@
 " | StrOutputParser()\n",
 ")\n",
 "\n",
-"retrieval_chain.invoke(\"where did harrison work?\")"
+"retrieval_chain.invoke(\"where did harrison work?\")\n"
 ]
 },
 {
@@ -131,7 +131,7 @@
 "source": [
 "%%timeit\n",
 "\n",
-"joke_chain.invoke({\"topic\": \"bear\"})"
+"joke_chain.invoke({\"topic\": \"bear\"})\n"
 ]
 },
 {
@@ -151,7 +151,7 @@
 "source": [
 "%%timeit\n",
 "\n",
-"poem_chain.invoke({\"topic\": \"bear\"})"
+"poem_chain.invoke({\"topic\": \"bear\"})\n"
 ]
 },
 {
@@ -171,7 +171,7 @@
 "source": [
 "%%timeit\n",
 "\n",
-"map_chain.invoke({\"topic\": \"bear\"})"
+"map_chain.invoke({\"topic\": \"bear\"})\n"
 ]
 }
 ],
```
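The `%%timeit` cells above make a latency point that survives the rename: `RunnableParallel` runs its branches concurrently, so the map should take roughly as long as its slowest branch, not the sum. A toy sketch with sleeps in place of model calls (timings approximate):

```python
import time
from langchain.schema.runnable import RunnableLambda, RunnableParallel

def slow(_: dict) -> str:
    time.sleep(1.0)
    return "slow"

def slower(_: dict) -> str:
    time.sleep(2.0)
    return "slower"

start = time.time()
RunnableParallel(a=RunnableLambda(slow), b=RunnableLambda(slower)).invoke({})
print(f"{time.time() - start:.1f}s")  # roughly 2s rather than 3s
```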
```diff
@@ -131,7 +131,7 @@
 ],
 "source": [
 "# The input schema of the chain is the input schema of its first part, the prompt.\n",
-"chain.input_schema.schema()"
+"chain.input_schema.schema()\n"
 ]
 },
 {
@@ -244,7 +244,7 @@
 ],
 "source": [
 "# The output schema of the chain is the output schema of its last part, in this case a ChatModel, which outputs a ChatMessage\n",
-"chain.output_schema.schema()"
+"chain.output_schema.schema()\n"
 ]
 },
 {
@@ -783,7 +783,7 @@
 ],
 "source": [
 "async for chunk in retrieval_chain.astream_log(\"where did harrison work?\", include_names=['Docs'], diff=False):\n",
-" print(chunk)"
+" print(chunk)\n"
 ]
 },
 {
@@ -793,7 +793,7 @@
 "source": [
 "## Parallelism\n",
 "\n",
-"Let's take a look at how LangChain Expression Language support parallel requests as much as possible. For example, when using a RunnableMap (often written as a dictionary) it executes each element in parallel."
+"Let's take a look at how LangChain Expression Language support parallel requests as much as possible. For example, when using a RunnableParallel (often written as a dictionary) it executes each element in parallel."
 ]
 },
 {
@@ -803,13 +803,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain.schema.runnable import RunnableMap\n",
+"from langchain.schema.runnable import RunnableParallel\n",
 "chain1 = ChatPromptTemplate.from_template(\"tell me a joke about {topic}\") | model\n",
 "chain2 = ChatPromptTemplate.from_template(\"write a short (2 line) poem about {topic}\") | model\n",
-"combined = RunnableMap({\n",
-" \"joke\": chain1,\n",
-" \"poem\": chain2,\n",
-"})\n"
+"combined = RunnableParallel(joke=chain1, poem=chain2)\n"
 ]
 },
 {
```
```diff
@@ -27,7 +27,7 @@
 "source": [
 "from langchain.chat_models.fireworks import ChatFireworks\n",
 "from langchain.schema import SystemMessage, HumanMessage\n",
-"import os"
+"import os\n"
 ]
 },
 {
@@ -56,7 +56,7 @@
 " os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Fireworks API Key:\")\n",
 "\n",
 "# Initialize a Fireworks chat model\n",
-"chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")"
+"chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")\n"
 ]
 },
 {
@@ -116,7 +116,7 @@
 "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":1, \"max_tokens\": 20, \"top_p\": 1})\n",
 "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
 "human_message = HumanMessage(content=\"How's the weather today?\")\n",
-"chat([system_message, human_message])"
+"chat([system_message, human_message])\n"
 ]
 },
 {
@@ -144,7 +144,7 @@
 "source": [
 "from langchain.chat_models import ChatFireworks\n",
 "from langchain.memory import ConversationBufferMemory\n",
-"from langchain.schema.runnable import RunnableMap\n",
+"from langchain.schema.runnable import RunnablePassthrough\n",
 "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
 "\n",
 "llm = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":0, \"max_tokens\":64, \"top_p\":1.0})\n",
@@ -152,7 +152,7 @@
 " (\"system\", \"You are a helpful chatbot that speaks like a pirate.\"),\n",
 " MessagesPlaceholder(variable_name=\"history\"),\n",
 " (\"human\", \"{input}\")\n",
-"])"
+"])\n"
 ]
 },
 {
@@ -182,7 +182,7 @@
 ],
 "source": [
 "memory = ConversationBufferMemory(return_messages=True)\n",
-"memory.load_memory_variables({})"
+"memory.load_memory_variables({})\n"
 ]
 },
 {
@@ -200,13 +200,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"chain = RunnableMap({\n",
-" \"input\": lambda x: x[\"input\"],\n",
-" \"memory\": memory.load_memory_variables\n",
-"}) | {\n",
-" \"input\": lambda x: x[\"input\"],\n",
-" \"history\": lambda x: x[\"memory\"][\"history\"]\n",
-"} | prompt | llm.bind(stop=[\"\\n\\n\"])"
+"chain = RunnablePassthrough.assign(\n",
+" history=memory.load_memory_variables | (lambda x: x[\"history\"])\n",
+") | prompt | llm.bind(stop=[\"\\n\\n\"])\n"
 ]
 },
 {
@@ -237,7 +233,7 @@
 "source": [
 "inputs = {\"input\": \"hi im bob\"}\n",
 "response = chain.invoke(inputs)\n",
-"response"
+"response\n"
 ]
 },
 {
@@ -268,7 +264,7 @@
 ],
 "source": [
 "memory.save_context(inputs, {\"output\": response.content})\n",
-"memory.load_memory_variables({})"
+"memory.load_memory_variables({})\n"
 ]
 },
 {
@@ -298,7 +294,7 @@
 ],
 "source": [
 "inputs = {\"input\": \"whats my name\"}\n",
-"chain.invoke(inputs)"
+"chain.invoke(inputs)\n"
 ]
 }
 ],
```
```diff
@@ -19,7 +19,7 @@
 "outputs": [],
 "source": [
 "# install the opaqueprompts and langchain packages\n",
-"! pip install opaqueprompts langchain"
+"! pip install opaqueprompts langchain\n"
 ]
 },
 {
@@ -40,7 +40,7 @@
 "# Set API keys\n",
 "\n",
 "os.environ['OPAQUEPROMPTS_API_KEY'] = \"<OPAQUEPROMPTS_API_KEY>\"\n",
-"os.environ['OPENAI_API_KEY'] = \"<OPENAI_API_KEY>\""
+"os.environ['OPENAI_API_KEY'] = \"<OPENAI_API_KEY>\"\n"
 ]
 },
 {
@@ -59,7 +59,8 @@
 "outputs": [],
 "source": [
 "import langchain\n",
-"from langchain.chains import LLMChain\nfrom langchain.prompts import PromptTemplate\n",
+"from langchain.chains import LLMChain\n",
+"from langchain.prompts import PromptTemplate\n",
 "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
 "from langchain.llms import OpenAI\n",
 "from langchain.memory import ConversationBufferWindowMemory\n",
@@ -117,7 +118,7 @@
 " {\"question\": \"\"\"Write a message to remind John to do password reset for his website to stay secure.\"\"\"},\n",
 " callbacks=[StdOutCallbackHandler()],\n",
 " )\n",
-")"
+")\n"
 ]
 },
 {
@@ -173,7 +174,7 @@
 "outputs": [],
 "source": [
 "import langchain.utilities.opaqueprompts as op\n",
-"from langchain.schema.runnable import RunnableMap\n",
+"from langchain.schema.runnable import RunnablePassthrough\n",
 "from langchain.schema.output_parser import StrOutputParser\n",
 "\n",
 "\n",
@@ -181,19 +182,16 @@
 "llm = OpenAI()\n",
 "pg_chain = (\n",
 " op.sanitize\n",
-" | RunnableMap(\n",
-" {\n",
-" \"response\": (lambda x: x[\"sanitized_input\"])\n",
+" | RunnablePassthrough.assign(\n",
+" response=(lambda x: x[\"sanitized_input\"])\n",
 " | prompt\n",
 " | llm\n",
 " | StrOutputParser(),\n",
-" \"secure_context\": lambda x: x[\"secure_context\"],\n",
-" }\n",
 " )\n",
 " | (lambda x: op.desanitize(x[\"response\"], x[\"secure_context\"]))\n",
 ")\n",
 "\n",
-"pg_chain.invoke({\"question\": \"Write a text message to remind John to do password reset for his website through his email to stay secure.\", \"history\": \"\"})"
+"pg_chain.invoke({\"question\": \"Write a text message to remind John to do password reset for his website through his email to stay secure.\", \"history\": \"\"})\n"
 ]
 }
 ],
```
```diff
@@ -4,7 +4,7 @@ from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.output_parser import NoOpOutputParser
 from langchain.schema.prompt_template import BasePromptTemplate
-from langchain.schema.runnable import RunnableMap, RunnableSequence
+from langchain.schema.runnable import RunnableParallel, RunnableSequence
 from langchain.utilities.sql_database import SQLDatabase
 
 
@@ -60,7 +60,7 @@ def create_sql_query_chain(
     if "dialect" in prompt_to_use.input_variables:
         inputs["dialect"] = lambda _: (db.dialect, prompt_to_use)
     return (
-        RunnableMap(inputs)
+        RunnableParallel(inputs)
         | prompt_to_use
        | llm.bind(stop=["\nSQLResult:"])
         | NoOpOutputParser()
```
```diff
@@ -5,6 +5,7 @@ from langchain.schema.runnable.base import (
     RunnableGenerator,
     RunnableLambda,
     RunnableMap,
+    RunnableParallel,
     RunnableSequence,
     RunnableSerializable,
 )
@@ -30,6 +31,7 @@ __all__ = [
     "RunnableGenerator",
     "RunnableLambda",
     "RunnableMap",
+    "RunnableParallel",
     "RunnablePassthrough",
     "RunnableSequence",
     "RunnableWithFallbacks",
```
```diff
@@ -1490,7 +1490,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
             yield chunk


-class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
+class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
     """
     A runnable that runs a mapping of runnables in parallel,
     and returns a mapping of their outputs.
@@ -1500,16 +1500,27 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):

     def __init__(
         self,
-        steps: Mapping[
-            str,
-            Union[
-                Runnable[Input, Any],
-                Callable[[Input], Any],
-                Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
-            ],
+        __steps: Optional[
+            Mapping[
+                str,
+                Union[
+                    Runnable[Input, Any],
+                    Callable[[Input], Any],
+                    Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
+                ],
+            ]
+        ] = None,
+        **kwargs: Union[
+            Runnable[Input, Any],
+            Callable[[Input], Any],
+            Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
         ],
     ) -> None:
-        super().__init__(steps={key: coerce_to_runnable(r) for key, r in steps.items()})
+        merged = {**__steps} if __steps is not None else {}
+        merged.update(kwargs)
+        super().__init__(
+            steps={key: coerce_to_runnable(r) for key, r in merged.items()}
+        )

     @classmethod
     def is_lc_serializable(cls) -> bool:
@@ -1538,7 +1549,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
         ):
             # This is correct, but pydantic typings/mypy don't think so.
             return create_model(  # type: ignore[call-overload]
-                "RunnableMapInput",
+                "RunnableParallelInput",
                 **{
                     k: (v.annotation, v.default)
                     for step in self.steps.values()
@@ -1553,7 +1564,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
     def output_schema(self) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
-            "RunnableMapOutput",
+            "RunnableParallelOutput",
             **{k: (v.OutputType, None) for k, v in self.steps.items()},
         )

@@ -1797,6 +1808,10 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
             yield chunk


+# We support both names
+RunnableMap = RunnableParallel
+
+
 class RunnableGenerator(Runnable[Input, Output]):
     """
     A runnable that runs a generator function.
@@ -2435,10 +2450,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
     elif callable(thing):
         return RunnableLambda(cast(Callable[[Input], Output], thing))
     elif isinstance(thing, dict):
-        runnables: Mapping[str, Runnable[Any, Any]] = {
-            key: coerce_to_runnable(r) for key, r in thing.items()
-        }
-        return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
+        return cast(Runnable[Input, Output], RunnableParallel(thing))
     else:
         raise TypeError(
             f"Expected a Runnable, callable or dict."
```
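Given the new `__init__` above, the following constructions should all be equivalent: the optional positional mapping and the kwargs are merged into one `steps` dict. A sketch of what the signature implies, with plain callables that get coerced to `RunnableLambda`:

```python
from langchain.schema.runnable import RunnableParallel

def upper(x: str) -> str:
    return x.upper()

def lower(x: str) -> str:
    return x.lower()

a = RunnableParallel({"upper": upper, "lower": lower})  # old dict form
b = RunnableParallel(upper=upper, lower=lower)          # new kwargs form
c = RunnableParallel({"upper": upper}, lower=lower)     # mixed

assert a.invoke("Hi") == b.invoke("Hi") == c.invoke("Hi") == {"upper": "HI", "lower": "hi"}
```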
```diff
@@ -21,7 +21,7 @@ from langchain.pydantic_v1 import BaseModel, create_model
 from langchain.schema.runnable.base import (
     Input,
     Runnable,
-    RunnableMap,
+    RunnableParallel,
     RunnableSerializable,
 )
 from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
@@ -83,7 +83,7 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
         A runnable that merges the Dict input with the output produced by the
         mapping argument.
         """
-        return RunnableAssign(RunnableMap(kwargs))
+        return RunnableAssign(RunnableParallel(kwargs))

     def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
         return self._call_with_config(identity, input, config)
@@ -119,9 +119,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
     A runnable that assigns key-value pairs to Dict[str, Any] inputs.
     """

-    mapper: RunnableMap[Dict[str, Any]]
+    mapper: RunnableParallel[Dict[str, Any]]

-    def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None:
+    def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
         super().__init__(mapper=mapper, **kwargs)

     @classmethod
```
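The one-line change above is also a useful reading of what `.assign` is: `RunnablePassthrough.assign(**kwargs)` simply wraps a `RunnableParallel` built from the kwargs in a `RunnableAssign`, which merges the parallel map's output into the input dict. A sketch of that equivalence:

```python
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.schema.runnable.passthrough import RunnableAssign

def double(x: dict) -> int:
    return x["n"] * 2

via_assign = RunnablePassthrough.assign(doubled=double)
via_parts = RunnableAssign(RunnableParallel(doubled=double))

assert via_assign.invoke({"n": 2}) == via_parts.invoke({"n": 2}) == {"n": 2, "doubled": 4}
```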
```diff
@@ -5,7 +5,7 @@ from langchain.llms.opaqueprompts import OpaquePrompts
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.prompts import PromptTemplate
 from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnableMap
+from langchain.schema.runnable import RunnableParallel

 prompt_template = """
 As an AI assistant, you will answer questions according to given context.
@@ -64,14 +64,12 @@ def test_opaqueprompts_functions() -> None:
     llm = OpenAI()
     pg_chain = (
         op.sanitize
-        | RunnableMap(
-            {
-                "response": (lambda x: x["sanitized_input"])  # type: ignore
+        | RunnableParallel(
+            secure_context=lambda x: x["secure_context"],  # type: ignore
+            response=(lambda x: x["sanitized_input"])  # type: ignore
             | prompt
             | llm
             | StrOutputParser(),
-                "secure_context": lambda x: x["secure_context"],
-            }
         )
         | (lambda x: op.desanitize(x["response"], x["secure_context"]))
     )
```
```diff
@@ -629,7 +629,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
@@ -643,7 +643,7 @@
 "base",
 "RunnableLambda"
 ],
-"repr": "RunnableLambda(...)"
+"repr": "RunnableLambda(lambda x: x['key'])"
 },
 "input": {
 "lc": 1,
@@ -652,7 +652,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
@@ -666,7 +666,7 @@
 "base",
 "RunnableLambda"
 ],
-"repr": "RunnableLambda(...)"
+"repr": "RunnableLambda(lambda x: x['question'])"
 }
 }
 }
@@ -709,7 +709,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
@@ -1438,7 +1438,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
@@ -1461,7 +1461,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
@@ -3455,7 +3455,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
@@ -3769,7 +3769,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
@@ -3910,7 +3910,7 @@
 "langchain",
 "schema",
 "runnable",
-"RunnableMap"
+"RunnableParallel"
 ],
 "kwargs": {
 "steps": {
```
@ -53,7 +53,7 @@ from langchain.schema.runnable import (
|
|||||||
RunnableBranch,
|
RunnableBranch,
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableMap,
|
RunnableParallel,
|
||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
@ -491,7 +491,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
"properties": {"name": {"title": "Name", "type": "string"}},
|
"properties": {"name": {"title": "Name", "type": "string"}},
|
||||||
}
|
}
|
||||||
assert seq_w_map.output_schema.schema() == {
|
assert seq_w_map.output_schema.schema() == {
|
||||||
"title": "RunnableMapOutput",
|
"title": "RunnableParallelOutput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"original": {"title": "Original", "type": "string"},
|
"original": {"title": "Original", "type": "string"},
|
||||||
@ -593,7 +593,7 @@ def test_schema_complex_seq() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert chain2.input_schema.schema() == {
|
assert chain2.input_schema.schema() == {
|
||||||
"title": "RunnableMapInput",
|
"title": "RunnableParallelInput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"person": {"title": "Person", "type": "string"},
|
"person": {"title": "Person", "type": "string"},
|
||||||
@ -1656,12 +1656,12 @@ async def test_stream_log_retriever() -> None:
|
|||||||
RunLogPatch(
|
RunLogPatch(
|
||||||
{
|
{
|
||||||
"op": "add",
|
"op": "add",
|
||||||
"path": "/logs/RunnableMap",
|
"path": "/logs/RunnableParallel",
|
||||||
"value": {
|
"value": {
|
||||||
"end_time": None,
|
"end_time": None,
|
||||||
"final_output": None,
|
"final_output": None,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "RunnableMap",
|
"name": "RunnableParallel",
|
||||||
"start_time": "2023-01-01T00:00:00.000",
|
"start_time": "2023-01-01T00:00:00.000",
|
||||||
"streamed_output_str": [],
|
"streamed_output_str": [],
|
||||||
"tags": ["seq:step:1"],
|
"tags": ["seq:step:1"],
|
||||||
@ -1733,7 +1733,7 @@ async def test_stream_log_retriever() -> None:
|
|||||||
RunLogPatch(
|
RunLogPatch(
|
||||||
{
|
{
|
||||||
"op": "add",
|
"op": "add",
|
||||||
"path": "/logs/RunnableMap/final_output",
|
"path": "/logs/RunnableParallel/final_output",
|
||||||
"value": {
|
"value": {
|
||||||
"documents": [
|
"documents": [
|
||||||
Document(page_content="foo"),
|
Document(page_content="foo"),
|
||||||
@ -1744,7 +1744,7 @@ async def test_stream_log_retriever() -> None:
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"op": "add",
|
"op": "add",
|
||||||
"path": "/logs/RunnableMap/end_time",
|
"path": "/logs/RunnableParallel/end_time",
|
||||||
"value": "2023-01-01T00:00:00.000",
|
"value": "2023-01-01T00:00:00.000",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@ -1792,8 +1792,8 @@ async def test_stream_log_retriever() -> None:
|
|||||||
"FakeListLLM:2",
|
"FakeListLLM:2",
|
||||||
"Retriever",
|
"Retriever",
|
||||||
"RunnableLambda",
|
"RunnableLambda",
|
||||||
"RunnableMap",
|
"RunnableParallel",
|
||||||
"RunnableMap:2",
|
"RunnableParallel:2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
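The rename also shows up in streamed run logs: entries that used to be recorded under `/logs/RunnableMap` now appear under `/logs/RunnableParallel`, as the hunks above assert. A minimal sketch of how a caller would observe this, assuming `chain` is any LCEL chain whose first step is a `RunnableParallel`, using the `astream_log` API these tests exercise:

```python
# Sketch only: `chain` and its input are placeholders, not part of this diff.
async def show_log_paths(chain, question: str) -> None:
    async for patch in chain.astream_log({"question": question}):
        for op in patch.ops:
            # After this change the paths read "/logs/RunnableParallel...",
            # where they previously read "/logs/RunnableMap...".
            if op["path"].startswith("/logs/"):
                print(op["op"], op["path"])
```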
@@ -1977,7 +1977,7 @@ Question:

     assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
-    assert isinstance(chain.first, RunnableMap)
+    assert isinstance(chain.first, RunnableParallel)
     assert chain.middle == [prompt, chat]
     assert chain.last == parser
     assert dumps(chain, pretty=True) == snapshot
@@ -2013,7 +2013,7 @@ What is your name?"""
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 4
     map_run = parent_run.child_runs[0]
-    assert map_run.name == "RunnableMap"
+    assert map_run.name == "RunnableParallel"
     assert len(map_run.child_runs) == 3

@@ -2043,7 +2043,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
     assert chain.middle == [RunnableLambda(passthrough)]
-    assert isinstance(chain.last, RunnableMap)
+    assert isinstance(chain.last, RunnableParallel)
     assert dumps(chain, pretty=True) == snapshot

     # Test invoke
@@ -2074,7 +2074,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 3
     map_run = parent_run.child_runs[2]
-    assert map_run.name == "RunnableMap"
+    assert map_run.name == "RunnableParallel"
     assert len(map_run.child_runs) == 2

@@ -2142,11 +2142,9 @@ async def test_higher_order_lambda_runnable(
     english_chain = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    input_map: Runnable = RunnableMap(
-        {  # type: ignore[arg-type]
-            "key": lambda x: x["key"],
-            "input": {"question": lambda x: x["question"]},
-        }
+    input_map: Runnable = RunnableParallel(
+        key=lambda x: x["key"],
+        input={"question": lambda x: x["question"]},
     )

     def router(input: Dict[str, Any]) -> Runnable:
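Note that the hunk above also switches the test from a dict literal (which needed a `type: ignore`) to keyword arguments. Both construction styles should yield an equivalent runnable; a minimal sketch under that assumption, using only names that appear in this diff:

```python
from langchain.schema.runnable import RunnableParallel

# Dict form: keys become output keys, values are coerced to runnables.
by_dict = RunnableParallel({
    "key": lambda x: x["key"],
    "input": {"question": lambda x: x["question"]},
})

# Keyword form: the style the updated test uses.
by_kwargs = RunnableParallel(
    key=lambda x: x["key"],
    input={"question": lambda x: x["question"]},
)

payload = {"key": "math", "question": "2 + 2"}
assert by_dict.invoke(payload) == by_kwargs.invoke(payload)
```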
@@ -2158,7 +2156,8 @@ async def test_higher_order_lambda_runnable(
             raise ValueError(f"Unknown key: {input['key']}")

     chain: Runnable = input_map | router
-    assert dumps(chain, pretty=True) == snapshot
+    if sys.version_info >= (3, 9):
+        assert dumps(chain, pretty=True) == snapshot

     result = chain.invoke({"key": "math", "question": "2 + 2"})
     assert result == "4"
@@ -2256,7 +2255,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
     assert chain.middle == [RunnableLambda(passthrough)]
-    assert isinstance(chain.last, RunnableMap)
+    assert isinstance(chain.last, RunnableParallel)
     assert dumps(chain, pretty=True) == snapshot

     # Test invoke
@@ -2293,7 +2292,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 3
     map_run = parent_run.child_runs[2]
-    assert map_run.name == "RunnableMap"
+    assert map_run.name == "RunnableParallel"
     assert len(map_run.child_runs) == 3

@@ -2460,12 +2459,12 @@ async def test_map_astream() -> None:
     assert final_state.state["logs"]["ChatPromptTemplate"][
         "final_output"
     ] == prompt.invoke({"question": "What is your name?"})
-    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
     assert sorted(final_state.state["logs"]) == [
         "ChatPromptTemplate",
         "FakeListChatModel",
         "FakeStreamingListLLM",
-        "RunnableMap",
+        "RunnableParallel",
         "RunnablePassthrough",
     ]

@@ -2505,11 +2504,11 @@ async def test_map_astream() -> None:
     assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
         prompt.invoke({"question": "What is your name?"})
     )
-    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
     assert sorted(final_state.state["logs"]) == [
         "ChatPromptTemplate",
         "FakeStreamingListLLM",
-        "RunnableMap",
+        "RunnableParallel",
         "RunnablePassthrough",
     ]

@@ -2910,7 +2909,7 @@ def llm_chain_with_fallbacks() -> RunnableSequence:
     pass_llm = FakeListLLM(responses=["bar"])

     prompt = PromptTemplate.from_template("what did baz say to {buz}")
-    return RunnableMap({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
+    return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
         [prompt | pass_llm]
     )

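Since the PR keeps `RunnableMap` as an alias, snapshots and tests move to the new name while existing imports keep working. A minimal sketch of what the alias implies, assuming it is a plain module-level alias for `RunnableParallel`:

```python
from langchain.schema.runnable import RunnableMap, RunnableParallel

# Old code that constructs a RunnableMap still gets a RunnableParallel.
m = RunnableMap({"buz": lambda x: x})
assert isinstance(m, RunnableParallel)
```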