Rename RunnableMap to RunnableParallel (#11487)

- keep alias for RunnableMap
- update docs to use RunnableParallel and RunnablePassthrough.assign

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-10-09 11:22:03 +01:00 committed by GitHub
parent 6a10e8ef31
commit 628cc4cce8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 190 additions and 213 deletions

View File

@ -17,9 +17,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from operator import itemgetter\n",
"from langchain.chat_models import ChatOpenAI\n", "from langchain.chat_models import ChatOpenAI\n",
"from langchain.memory import ConversationBufferMemory\n", "from langchain.memory import ConversationBufferMemory\n",
"from langchain.schema.runnable import RunnableMap\n", "from langchain.schema.runnable import RunnablePassthrough\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n", "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"\n", "\n",
"model = ChatOpenAI()\n", "model = ChatOpenAI()\n",
@ -27,7 +28,7 @@
" (\"system\", \"You are a helpful chatbot\"),\n", " (\"system\", \"You are a helpful chatbot\"),\n",
" MessagesPlaceholder(variable_name=\"history\"),\n", " MessagesPlaceholder(variable_name=\"history\"),\n",
" (\"human\", \"{input}\")\n", " (\"human\", \"{input}\")\n",
"])" "])\n"
] ]
}, },
{ {
@ -37,7 +38,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"memory = ConversationBufferMemory(return_messages=True)" "memory = ConversationBufferMemory(return_messages=True)\n"
] ]
}, },
{ {
@ -58,7 +59,7 @@
} }
], ],
"source": [ "source": [
"memory.load_memory_variables({})" "memory.load_memory_variables({})\n"
] ]
}, },
{ {
@ -68,13 +69,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"chain = RunnableMap({\n", "chain = RunnablePassthrough.assign(\n",
" \"input\": lambda x: x[\"input\"],\n", " memory=memory.load_memory_variables | itemgetter(\"history\")\n",
" \"memory\": memory.load_memory_variables\n", ") | prompt | model\n"
"}) | {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"history\": lambda x: x[\"memory\"][\"history\"]\n",
"} | prompt | model"
] ]
}, },
{ {
@ -97,7 +94,7 @@
"source": [ "source": [
"inputs = {\"input\": \"hi im bob\"}\n", "inputs = {\"input\": \"hi im bob\"}\n",
"response = chain.invoke(inputs)\n", "response = chain.invoke(inputs)\n",
"response" "response\n"
] ]
}, },
{ {
@ -107,7 +104,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"memory.save_context(inputs, {\"output\": response.content})" "memory.save_context(inputs, {\"output\": response.content})\n"
] ]
}, },
{ {
@ -129,7 +126,7 @@
} }
], ],
"source": [ "source": [
"memory.load_memory_variables({})" "memory.load_memory_variables({})\n"
] ]
}, },
{ {
@ -152,7 +149,7 @@
"source": [ "source": [
"inputs = {\"input\": \"whats my name\"}\n", "inputs = {\"input\": \"whats my name\"}\n",
"response = chain.invoke(inputs)\n", "response = chain.invoke(inputs)\n",
"response" "response\n"
] ]
} }
], ],

View File

@ -8,7 +8,7 @@
"---\n", "---\n",
"sidebar_position: 0\n", "sidebar_position: 0\n",
"title: Prompt + LLM\n", "title: Prompt + LLM\n",
"---" "---\n"
] ]
}, },
{ {
@ -47,7 +47,7 @@
"\n", "\n",
"prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")\n", "prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")\n",
"model = ChatOpenAI()\n", "model = ChatOpenAI()\n",
"chain = prompt | model" "chain = prompt | model\n"
] ]
}, },
{ {
@ -68,7 +68,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke({\"foo\": \"bears\"})" "chain.invoke({\"foo\": \"bears\"})\n"
] ]
}, },
{ {
@ -94,7 +94,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"chain = prompt | model.bind(stop=[\"\\n\"])" "chain = prompt | model.bind(stop=[\"\\n\"])\n"
] ]
}, },
{ {
@ -115,7 +115,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke({\"foo\": \"bears\"})" "chain.invoke({\"foo\": \"bears\"})\n"
] ]
}, },
{ {
@ -153,7 +153,7 @@
" }\n", " }\n",
" }\n", " }\n",
" ]\n", " ]\n",
"chain = prompt | model.bind(function_call= {\"name\": \"joke\"}, functions= functions)" "chain = prompt | model.bind(function_call= {\"name\": \"joke\"}, functions= functions)\n"
] ]
}, },
{ {
@ -174,7 +174,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke({\"foo\": \"bears\"}, config={})" "chain.invoke({\"foo\": \"bears\"}, config={})\n"
] ]
}, },
{ {
@ -196,7 +196,7 @@
"source": [ "source": [
"from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.output_parser import StrOutputParser\n",
"\n", "\n",
"chain = prompt | model | StrOutputParser()" "chain = prompt | model | StrOutputParser()\n"
] ]
}, },
{ {
@ -225,7 +225,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke({\"foo\": \"bears\"})" "chain.invoke({\"foo\": \"bears\"})\n"
] ]
}, },
{ {
@ -251,7 +251,7 @@
" prompt \n", " prompt \n",
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n", " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
" | JsonOutputFunctionsParser()\n", " | JsonOutputFunctionsParser()\n",
")" ")\n"
] ]
}, },
{ {
@ -273,7 +273,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke({\"foo\": \"bears\"})" "chain.invoke({\"foo\": \"bears\"})\n"
] ]
}, },
{ {
@ -289,7 +289,7 @@
" prompt \n", " prompt \n",
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n", " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n", " | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
")" ")\n"
] ]
}, },
{ {
@ -310,7 +310,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke({\"foo\": \"bears\"})" "chain.invoke({\"foo\": \"bears\"})\n"
] ]
}, },
{ {
@ -332,13 +332,13 @@
"source": [ "source": [
"from langchain.schema.runnable import RunnableMap, RunnablePassthrough\n", "from langchain.schema.runnable import RunnableMap, RunnablePassthrough\n",
"\n", "\n",
"map_ = RunnableMap({\"foo\": RunnablePassthrough()})\n", "map_ = RunnableMap(foo=RunnablePassthrough())\n",
"chain = (\n", "chain = (\n",
" map_ \n", " map_ \n",
" | prompt\n", " | prompt\n",
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n", " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n", " | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
")" ")\n"
] ]
}, },
{ {
@ -359,7 +359,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke(\"bears\")" "chain.invoke(\"bears\")\n"
] ]
}, },
{ {
@ -382,7 +382,7 @@
" | prompt\n", " | prompt\n",
" | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n", " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n",
" | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n", " | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
")" ")\n"
] ]
}, },
{ {
@ -403,7 +403,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke(\"bears\")" "chain.invoke(\"bears\")\n"
] ]
} }
], ],

View File

@ -8,7 +8,7 @@
"---\n", "---\n",
"sidebar_position: 1\n", "sidebar_position: 1\n",
"title: RAG\n", "title: RAG\n",
"---" "---\n"
] ]
}, },
{ {
@ -26,7 +26,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install langchain openai faiss-cpu tiktoken" "!pip install langchain openai faiss-cpu tiktoken\n"
] ]
}, },
{ {
@ -43,7 +43,7 @@
"from langchain.embeddings import OpenAIEmbeddings\n", "from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough\n", "from langchain.schema.runnable import RunnablePassthrough\n",
"from langchain.vectorstores import FAISS" "from langchain.vectorstores import FAISS\n"
] ]
}, },
{ {
@ -63,7 +63,7 @@
"\"\"\"\n", "\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n", "prompt = ChatPromptTemplate.from_template(template)\n",
"\n", "\n",
"model = ChatOpenAI()" "model = ChatOpenAI()\n"
] ]
}, },
{ {
@ -78,7 +78,7 @@
" | prompt \n", " | prompt \n",
" | model \n", " | model \n",
" | StrOutputParser()\n", " | StrOutputParser()\n",
")" ")\n"
] ]
}, },
{ {
@ -99,7 +99,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke(\"where did harrison work?\")" "chain.invoke(\"where did harrison work?\")\n"
] ]
}, },
{ {
@ -122,7 +122,7 @@
" \"context\": itemgetter(\"question\") | retriever, \n", " \"context\": itemgetter(\"question\") | retriever, \n",
" \"question\": itemgetter(\"question\"), \n", " \"question\": itemgetter(\"question\"), \n",
" \"language\": itemgetter(\"language\")\n", " \"language\": itemgetter(\"language\")\n",
"} | prompt | model | StrOutputParser()" "} | prompt | model | StrOutputParser()\n"
] ]
}, },
{ {
@ -143,7 +143,7 @@
} }
], ],
"source": [ "source": [
"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})" "chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})\n"
] ]
}, },
{ {
@ -164,7 +164,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.schema.runnable import RunnableMap\n", "from langchain.schema.runnable import RunnableMap\n",
"from langchain.schema import format_document" "from langchain.schema import format_document\n"
] ]
}, },
{ {
@ -182,7 +182,7 @@
"{chat_history}\n", "{chat_history}\n",
"Follow Up Input: {question}\n", "Follow Up Input: {question}\n",
"Standalone question:\"\"\"\n", "Standalone question:\"\"\"\n",
"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)" "CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)\n"
] ]
}, },
{ {
@ -197,7 +197,7 @@
"\n", "\n",
"Question: {question}\n", "Question: {question}\n",
"\"\"\"\n", "\"\"\"\n",
"ANSWER_PROMPT = ChatPromptTemplate.from_template(template)" "ANSWER_PROMPT = ChatPromptTemplate.from_template(template)\n"
] ]
}, },
{ {
@ -210,7 +210,7 @@
"DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=\"{page_content}\")\n", "DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=\"{page_content}\")\n",
"def _combine_documents(docs, document_prompt = DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"):\n", "def _combine_documents(docs, document_prompt = DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"):\n",
" doc_strings = [format_document(doc, document_prompt) for doc in docs]\n", " doc_strings = [format_document(doc, document_prompt) for doc in docs]\n",
" return document_separator.join(doc_strings)" " return document_separator.join(doc_strings)\n"
] ]
}, },
{ {
@ -227,7 +227,7 @@
" human = \"Human: \" + dialogue_turn[0]\n", " human = \"Human: \" + dialogue_turn[0]\n",
" ai = \"Assistant: \" + dialogue_turn[1]\n", " ai = \"Assistant: \" + dialogue_turn[1]\n",
" buffer += \"\\n\" + \"\\n\".join([human, ai])\n", " buffer += \"\\n\" + \"\\n\".join([human, ai])\n",
" return buffer" " return buffer\n"
] ]
}, },
{ {
@ -238,18 +238,15 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"_inputs = RunnableMap(\n", "_inputs = RunnableMap(\n",
" {\n", " standalone_question=RunnablePassthrough.assign(\n",
" \"standalone_question\": {\n", " chat_history=lambda x: _format_chat_history(x['chat_history'])\n",
" \"question\": lambda x: x[\"question\"],\n", " ) | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
" \"chat_history\": lambda x: _format_chat_history(x['chat_history'])\n",
" } | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),\n",
" }\n",
")\n", ")\n",
"_context = {\n", "_context = {\n",
" \"context\": itemgetter(\"standalone_question\") | retriever | _combine_documents,\n", " \"context\": itemgetter(\"standalone_question\") | retriever | _combine_documents,\n",
" \"question\": lambda x: x[\"standalone_question\"]\n", " \"question\": lambda x: x[\"standalone_question\"]\n",
"}\n", "}\n",
"conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()" "conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()\n"
] ]
}, },
{ {
@ -273,7 +270,7 @@
"conversational_qa_chain.invoke({\n", "conversational_qa_chain.invoke({\n",
" \"question\": \"where did harrison work?\",\n", " \"question\": \"where did harrison work?\",\n",
" \"chat_history\": [],\n", " \"chat_history\": [],\n",
"})" "})\n"
] ]
}, },
{ {
@ -297,7 +294,7 @@
"conversational_qa_chain.invoke({\n", "conversational_qa_chain.invoke({\n",
" \"question\": \"where did he work?\",\n", " \"question\": \"where did he work?\",\n",
" \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n", " \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n",
"})" "})\n"
] ]
}, },
{ {
@ -317,7 +314,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.memory import ConversationBufferMemory" "from operator import itemgetter\n",
"from langchain.memory import ConversationBufferMemory\n"
] ]
}, },
{ {
@ -327,7 +325,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"memory = ConversationBufferMemory(return_messages=True, output_key=\"answer\", input_key=\"question\")" "memory = ConversationBufferMemory(return_messages=True, output_key=\"answer\", input_key=\"question\")\n"
] ]
}, },
{ {
@ -338,19 +336,10 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# First we add a step to load memory\n", "# First we add a step to load memory\n",
"# This needs to be a RunnableMap because its the first input\n", "# This adds a \"memory\" key to the input object\n",
"loaded_memory = RunnableMap(\n", "loaded_memory = RunnablePassthrough.assign(\n",
" {\n", " chat_history=memory.load_memory_variables | itemgetter(\"history\"),\n",
" \"question\": itemgetter(\"question\"),\n",
" \"memory\": memory.load_memory_variables,\n",
" }\n",
")\n", ")\n",
"# Next we add a step to expand memory into the variables\n",
"expanded_memory = {\n",
" \"question\": itemgetter(\"question\"),\n",
" \"chat_history\": lambda x: x[\"memory\"][\"history\"]\n",
"}\n",
"\n",
"# Now we calculate the standalone question\n", "# Now we calculate the standalone question\n",
"standalone_question = {\n", "standalone_question = {\n",
" \"standalone_question\": {\n", " \"standalone_question\": {\n",
@ -374,7 +363,7 @@
" \"docs\": itemgetter(\"docs\"),\n", " \"docs\": itemgetter(\"docs\"),\n",
"}\n", "}\n",
"# And now we put it all together!\n", "# And now we put it all together!\n",
"final_chain = loaded_memory | expanded_memory | standalone_question | retrieved_documents | answer" "final_chain = loaded_memory | expanded_memory | standalone_question | retrieved_documents | answer\n"
] ]
}, },
{ {
@ -398,7 +387,7 @@
"source": [ "source": [
"inputs = {\"question\": \"where did harrison work?\"}\n", "inputs = {\"question\": \"where did harrison work?\"}\n",
"result = final_chain.invoke(inputs)\n", "result = final_chain.invoke(inputs)\n",
"result" "result\n"
] ]
}, },
{ {
@ -411,7 +400,7 @@
"# Note that the memory does not save automatically\n", "# Note that the memory does not save automatically\n",
"# This will be improved in the future\n", "# This will be improved in the future\n",
"# For now you need to save it yourself\n", "# For now you need to save it yourself\n",
"memory.save_context(inputs, {\"answer\": result[\"answer\"].content})" "memory.save_context(inputs, {\"answer\": result[\"answer\"].content})\n"
] ]
}, },
{ {
@ -433,7 +422,7 @@
} }
], ],
"source": [ "source": [
"memory.load_memory_variables({})" "memory.load_memory_variables({})\n"
] ]
} }
], ],

View File

@ -8,7 +8,7 @@
"---\n", "---\n",
"sidebar_position: 3\n", "sidebar_position: 3\n",
"title: Querying a SQL DB\n", "title: Querying a SQL DB\n",
"---" "---\n"
] ]
}, },
{ {
@ -33,7 +33,7 @@
"\n", "\n",
"Question: {question}\n", "Question: {question}\n",
"SQL Query:\"\"\"\n", "SQL Query:\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)" "prompt = ChatPromptTemplate.from_template(template)\n"
] ]
}, },
{ {
@ -43,7 +43,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.utilities import SQLDatabase" "from langchain.utilities import SQLDatabase\n"
] ]
}, },
{ {
@ -61,7 +61,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")" "db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")\n"
] ]
}, },
{ {
@ -72,7 +72,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def get_schema(_):\n", "def get_schema(_):\n",
" return db.get_table_info()" " return db.get_table_info()\n"
] ]
}, },
{ {
@ -83,7 +83,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def run_query(query):\n", "def run_query(query):\n",
" return db.run(query)" " return db.run(query)\n"
] ]
}, },
{ {
@ -93,24 +93,18 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from operator import itemgetter\n",
"\n",
"from langchain.chat_models import ChatOpenAI\n", "from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnableLambda, RunnableMap\n", "from langchain.schema.runnable import RunnablePassthrough\n",
"\n", "\n",
"model = ChatOpenAI()\n", "model = ChatOpenAI()\n",
"\n", "\n",
"inputs = {\n",
" \"schema\": RunnableLambda(get_schema),\n",
" \"question\": itemgetter(\"question\")\n",
"}\n",
"sql_response = (\n", "sql_response = (\n",
" RunnableMap(inputs)\n", " RunnablePassthrough.assign(schema=get_schema)\n",
" | prompt\n", " | prompt\n",
" | model.bind(stop=[\"\\nSQLResult:\"])\n", " | model.bind(stop=[\"\\nSQLResult:\"])\n",
" | StrOutputParser()\n", " | StrOutputParser()\n",
" )" " )\n"
] ]
}, },
{ {
@ -131,7 +125,7 @@
} }
], ],
"source": [ "source": [
"sql_response.invoke({\"question\": \"How many employees are there?\"})" "sql_response.invoke({\"question\": \"How many employees are there?\"})\n"
] ]
}, },
{ {
@ -147,7 +141,7 @@
"Question: {question}\n", "Question: {question}\n",
"SQL Query: {query}\n", "SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n", "SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_template(template)" "prompt_response = ChatPromptTemplate.from_template(template)\n"
] ]
}, },
{ {
@ -158,19 +152,14 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"full_chain = (\n", "full_chain = (\n",
" RunnableMap({\n", " RunnablePassthrough.assign(query=sql_response) \n",
" \"question\": itemgetter(\"question\"),\n", " | RunnablePassthrough.assign(\n",
" \"query\": sql_response,\n", " schema=get_schema,\n",
" }) \n", " response=lambda x: db.run(x[\"query\"]),\n",
" | {\n", " )\n",
" \"schema\": RunnableLambda(get_schema),\n",
" \"question\": itemgetter(\"question\"),\n",
" \"query\": itemgetter(\"query\"),\n",
" \"response\": lambda x: db.run(x[\"query\"]) \n",
" } \n",
" | prompt_response \n", " | prompt_response \n",
" | model\n", " | model\n",
")" ")\n"
] ]
}, },
{ {
@ -191,7 +180,7 @@
} }
], ],
"source": [ "source": [
"full_chain.invoke({\"question\": \"How many employees are there?\"})" "full_chain.invoke({\"question\": \"How many employees are there?\"})\n"
] ]
}, },
{ {

View File

@ -5,9 +5,9 @@
"id": "b022ab74-794d-4c54-ad47-ff9549ddb9d2", "id": "b022ab74-794d-4c54-ad47-ff9549ddb9d2",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Use RunnableMaps\n", "# Use RunnableParallel/RunnableMap\n",
"\n", "\n",
"RunnableMaps make it easy to execute multiple Runnables in parallel, and to return the output of these Runnables as a map." "RunnableParallel (aka. RunnableMap) makes it easy to execute multiple Runnables in parallel, and to return the output of these Runnables as a map."
] ]
}, },
{ {
@ -31,16 +31,16 @@
"source": [ "source": [
"from langchain.chat_models import ChatOpenAI\n", "from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate\n", "from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.runnable import RunnableMap\n", "from langchain.schema.runnable import RunnableParallel\n",
"\n", "\n",
"\n", "\n",
"model = ChatOpenAI()\n", "model = ChatOpenAI()\n",
"joke_chain = ChatPromptTemplate.from_template(\"tell me a joke about {topic}\") | model\n", "joke_chain = ChatPromptTemplate.from_template(\"tell me a joke about {topic}\") | model\n",
"poem_chain = ChatPromptTemplate.from_template(\"write a 2-line poem about {topic}\") | model\n", "poem_chain = ChatPromptTemplate.from_template(\"write a 2-line poem about {topic}\") | model\n",
"\n", "\n",
"map_chain = RunnableMap({\"joke\": joke_chain, \"poem\": poem_chain,})\n", "map_chain = RunnableParallel(joke=joke_chain, poem=poem_chain)\n",
"\n", "\n",
"map_chain.invoke({\"topic\": \"bear\"})" "map_chain.invoke({\"topic\": \"bear\"})\n"
] ]
}, },
{ {
@ -91,7 +91,7 @@
" | StrOutputParser()\n", " | StrOutputParser()\n",
")\n", ")\n",
"\n", "\n",
"retrieval_chain.invoke(\"where did harrison work?\")" "retrieval_chain.invoke(\"where did harrison work?\")\n"
] ]
}, },
{ {
@ -131,7 +131,7 @@
"source": [ "source": [
"%%timeit\n", "%%timeit\n",
"\n", "\n",
"joke_chain.invoke({\"topic\": \"bear\"})" "joke_chain.invoke({\"topic\": \"bear\"})\n"
] ]
}, },
{ {
@ -151,7 +151,7 @@
"source": [ "source": [
"%%timeit\n", "%%timeit\n",
"\n", "\n",
"poem_chain.invoke({\"topic\": \"bear\"})" "poem_chain.invoke({\"topic\": \"bear\"})\n"
] ]
}, },
{ {
@ -171,7 +171,7 @@
"source": [ "source": [
"%%timeit\n", "%%timeit\n",
"\n", "\n",
"map_chain.invoke({\"topic\": \"bear\"})" "map_chain.invoke({\"topic\": \"bear\"})\n"
] ]
} }
], ],

View File

@ -131,7 +131,7 @@
], ],
"source": [ "source": [
"# The input schema of the chain is the input schema of its first part, the prompt.\n", "# The input schema of the chain is the input schema of its first part, the prompt.\n",
"chain.input_schema.schema()" "chain.input_schema.schema()\n"
] ]
}, },
{ {
@ -244,7 +244,7 @@
], ],
"source": [ "source": [
"# The output schema of the chain is the output schema of its last part, in this case a ChatModel, which outputs a ChatMessage\n", "# The output schema of the chain is the output schema of its last part, in this case a ChatModel, which outputs a ChatMessage\n",
"chain.output_schema.schema()" "chain.output_schema.schema()\n"
] ]
}, },
{ {
@ -783,7 +783,7 @@
], ],
"source": [ "source": [
"async for chunk in retrieval_chain.astream_log(\"where did harrison work?\", include_names=['Docs'], diff=False):\n", "async for chunk in retrieval_chain.astream_log(\"where did harrison work?\", include_names=['Docs'], diff=False):\n",
" print(chunk)" " print(chunk)\n"
] ]
}, },
{ {
@ -793,7 +793,7 @@
"source": [ "source": [
"## Parallelism\n", "## Parallelism\n",
"\n", "\n",
"Let's take a look at how LangChain Expression Language support parallel requests as much as possible. For example, when using a RunnableMap (often written as a dictionary) it executes each element in parallel." "Let's take a look at how LangChain Expression Language support parallel requests as much as possible. For example, when using a RunnableParallel (often written as a dictionary) it executes each element in parallel."
] ]
}, },
{ {
@ -803,13 +803,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.schema.runnable import RunnableMap\n", "from langchain.schema.runnable import RunnableParallel\n",
"chain1 = ChatPromptTemplate.from_template(\"tell me a joke about {topic}\") | model\n", "chain1 = ChatPromptTemplate.from_template(\"tell me a joke about {topic}\") | model\n",
"chain2 = ChatPromptTemplate.from_template(\"write a short (2 line) poem about {topic}\") | model\n", "chain2 = ChatPromptTemplate.from_template(\"write a short (2 line) poem about {topic}\") | model\n",
"combined = RunnableMap({\n", "combined = RunnableParallel(joke=chain1, poem=chain2)\n"
" \"joke\": chain1,\n",
" \"poem\": chain2,\n",
"})\n"
] ]
}, },
{ {

View File

@ -27,7 +27,7 @@
"source": [ "source": [
"from langchain.chat_models.fireworks import ChatFireworks\n", "from langchain.chat_models.fireworks import ChatFireworks\n",
"from langchain.schema import SystemMessage, HumanMessage\n", "from langchain.schema import SystemMessage, HumanMessage\n",
"import os" "import os\n"
] ]
}, },
{ {
@ -56,7 +56,7 @@
" os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Fireworks API Key:\")\n", " os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Fireworks API Key:\")\n",
"\n", "\n",
"# Initialize a Fireworks chat model\n", "# Initialize a Fireworks chat model\n",
"chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")" "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")\n"
] ]
}, },
{ {
@ -116,7 +116,7 @@
"chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":1, \"max_tokens\": 20, \"top_p\": 1})\n", "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":1, \"max_tokens\": 20, \"top_p\": 1})\n",
"system_message = SystemMessage(content=\"You are to chat with the user.\")\n", "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
"human_message = HumanMessage(content=\"How's the weather today?\")\n", "human_message = HumanMessage(content=\"How's the weather today?\")\n",
"chat([system_message, human_message])" "chat([system_message, human_message])\n"
] ]
}, },
{ {
@ -144,7 +144,7 @@
"source": [ "source": [
"from langchain.chat_models import ChatFireworks\n", "from langchain.chat_models import ChatFireworks\n",
"from langchain.memory import ConversationBufferMemory\n", "from langchain.memory import ConversationBufferMemory\n",
"from langchain.schema.runnable import RunnableMap\n", "from langchain.schema.runnable import RunnablePassthrough\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n", "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"\n", "\n",
"llm = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":0, \"max_tokens\":64, \"top_p\":1.0})\n", "llm = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":0, \"max_tokens\":64, \"top_p\":1.0})\n",
@ -152,7 +152,7 @@
" (\"system\", \"You are a helpful chatbot that speaks like a pirate.\"),\n", " (\"system\", \"You are a helpful chatbot that speaks like a pirate.\"),\n",
" MessagesPlaceholder(variable_name=\"history\"),\n", " MessagesPlaceholder(variable_name=\"history\"),\n",
" (\"human\", \"{input}\")\n", " (\"human\", \"{input}\")\n",
"])" "])\n"
] ]
}, },
{ {
@ -182,7 +182,7 @@
], ],
"source": [ "source": [
"memory = ConversationBufferMemory(return_messages=True)\n", "memory = ConversationBufferMemory(return_messages=True)\n",
"memory.load_memory_variables({})" "memory.load_memory_variables({})\n"
] ]
}, },
{ {
@ -200,13 +200,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"chain = RunnableMap({\n", "chain = RunnablePassthrough.assign(\n",
" \"input\": lambda x: x[\"input\"],\n", " history=memory.load_memory_variables | (lambda x: x[\"history\"])\n",
" \"memory\": memory.load_memory_variables\n", ") | prompt | llm.bind(stop=[\"\\n\\n\"])\n"
"}) | {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"history\": lambda x: x[\"memory\"][\"history\"]\n",
"} | prompt | llm.bind(stop=[\"\\n\\n\"])"
] ]
}, },
{ {
@ -237,7 +233,7 @@
"source": [ "source": [
"inputs = {\"input\": \"hi im bob\"}\n", "inputs = {\"input\": \"hi im bob\"}\n",
"response = chain.invoke(inputs)\n", "response = chain.invoke(inputs)\n",
"response" "response\n"
] ]
}, },
{ {
@ -268,7 +264,7 @@
], ],
"source": [ "source": [
"memory.save_context(inputs, {\"output\": response.content})\n", "memory.save_context(inputs, {\"output\": response.content})\n",
"memory.load_memory_variables({})" "memory.load_memory_variables({})\n"
] ]
}, },
{ {
@ -298,7 +294,7 @@
], ],
"source": [ "source": [
"inputs = {\"input\": \"whats my name\"}\n", "inputs = {\"input\": \"whats my name\"}\n",
"chain.invoke(inputs)" "chain.invoke(inputs)\n"
] ]
} }
], ],

View File

@ -19,7 +19,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# install the opaqueprompts and langchain packages\n", "# install the opaqueprompts and langchain packages\n",
"! pip install opaqueprompts langchain" "! pip install opaqueprompts langchain\n"
] ]
}, },
{ {
@ -40,7 +40,7 @@
"# Set API keys\n", "# Set API keys\n",
"\n", "\n",
"os.environ['OPAQUEPROMPTS_API_KEY'] = \"<OPAQUEPROMPTS_API_KEY>\"\n", "os.environ['OPAQUEPROMPTS_API_KEY'] = \"<OPAQUEPROMPTS_API_KEY>\"\n",
"os.environ['OPENAI_API_KEY'] = \"<OPENAI_API_KEY>\"" "os.environ['OPENAI_API_KEY'] = \"<OPENAI_API_KEY>\"\n"
] ]
}, },
{ {
@ -59,7 +59,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import langchain\n", "import langchain\n",
"from langchain.chains import LLMChain\nfrom langchain.prompts import PromptTemplate\n", "from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.callbacks.stdout import StdOutCallbackHandler\n", "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
"from langchain.llms import OpenAI\n", "from langchain.llms import OpenAI\n",
"from langchain.memory import ConversationBufferWindowMemory\n", "from langchain.memory import ConversationBufferWindowMemory\n",
@ -117,7 +118,7 @@
" {\"question\": \"\"\"Write a message to remind John to do password reset for his website to stay secure.\"\"\"},\n", " {\"question\": \"\"\"Write a message to remind John to do password reset for his website to stay secure.\"\"\"},\n",
" callbacks=[StdOutCallbackHandler()],\n", " callbacks=[StdOutCallbackHandler()],\n",
" )\n", " )\n",
")" ")\n"
] ]
}, },
{ {
@ -173,7 +174,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import langchain.utilities.opaqueprompts as op\n", "import langchain.utilities.opaqueprompts as op\n",
"from langchain.schema.runnable import RunnableMap\n", "from langchain.schema.runnable import RunnablePassthrough\n",
"from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.output_parser import StrOutputParser\n",
"\n", "\n",
"\n", "\n",
@ -181,19 +182,16 @@
"llm = OpenAI()\n", "llm = OpenAI()\n",
"pg_chain = (\n", "pg_chain = (\n",
" op.sanitize\n", " op.sanitize\n",
" | RunnableMap(\n", " | RunnablePassthrough.assign(\n",
" {\n", " response=(lambda x: x[\"sanitized_input\"])\n",
" \"response\": (lambda x: x[\"sanitized_input\"])\n",
" | prompt\n", " | prompt\n",
" | llm\n", " | llm\n",
" | StrOutputParser(),\n", " | StrOutputParser(),\n",
" \"secure_context\": lambda x: x[\"secure_context\"],\n",
" }\n",
" )\n", " )\n",
" | (lambda x: op.desanitize(x[\"response\"], x[\"secure_context\"]))\n", " | (lambda x: op.desanitize(x[\"response\"], x[\"secure_context\"]))\n",
")\n", ")\n",
"\n", "\n",
"pg_chain.invoke({\"question\": \"Write a text message to remind John to do password reset for his website through his email to stay secure.\", \"history\": \"\"})" "pg_chain.invoke({\"question\": \"Write a text message to remind John to do password reset for his website through his email to stay secure.\", \"history\": \"\"})\n"
] ]
} }
], ],

View File

@ -4,7 +4,7 @@ from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output_parser import NoOpOutputParser from langchain.schema.output_parser import NoOpOutputParser
from langchain.schema.prompt_template import BasePromptTemplate from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema.runnable import RunnableMap, RunnableSequence from langchain.schema.runnable import RunnableParallel, RunnableSequence
from langchain.utilities.sql_database import SQLDatabase from langchain.utilities.sql_database import SQLDatabase
@ -60,7 +60,7 @@ def create_sql_query_chain(
if "dialect" in prompt_to_use.input_variables: if "dialect" in prompt_to_use.input_variables:
inputs["dialect"] = lambda _: (db.dialect, prompt_to_use) inputs["dialect"] = lambda _: (db.dialect, prompt_to_use)
return ( return (
RunnableMap(inputs) RunnableParallel(inputs)
| prompt_to_use | prompt_to_use
| llm.bind(stop=["\nSQLResult:"]) | llm.bind(stop=["\nSQLResult:"])
| NoOpOutputParser() | NoOpOutputParser()

View File

@ -5,6 +5,7 @@ from langchain.schema.runnable.base import (
RunnableGenerator, RunnableGenerator,
RunnableLambda, RunnableLambda,
RunnableMap, RunnableMap,
RunnableParallel,
RunnableSequence, RunnableSequence,
RunnableSerializable, RunnableSerializable,
) )
@ -30,6 +31,7 @@ __all__ = [
"RunnableGenerator", "RunnableGenerator",
"RunnableLambda", "RunnableLambda",
"RunnableMap", "RunnableMap",
"RunnableParallel",
"RunnablePassthrough", "RunnablePassthrough",
"RunnableSequence", "RunnableSequence",
"RunnableWithFallbacks", "RunnableWithFallbacks",

View File

@ -1490,7 +1490,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
yield chunk yield chunk
class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]): class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
""" """
A runnable that runs a mapping of runnables in parallel, A runnable that runs a mapping of runnables in parallel,
and returns a mapping of their outputs. and returns a mapping of their outputs.
@ -1500,16 +1500,27 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
def __init__( def __init__(
self, self,
steps: Mapping[ __steps: Optional[
Mapping[
str, str,
Union[ Union[
Runnable[Input, Any], Runnable[Input, Any],
Callable[[Input], Any], Callable[[Input], Any],
Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
], ],
]
] = None,
**kwargs: Union[
Runnable[Input, Any],
Callable[[Input], Any],
Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
], ],
) -> None: ) -> None:
super().__init__(steps={key: coerce_to_runnable(r) for key, r in steps.items()}) merged = {**__steps} if __steps is not None else {}
merged.update(kwargs)
super().__init__(
steps={key: coerce_to_runnable(r) for key, r in merged.items()}
)
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
@ -1538,7 +1549,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
): ):
# This is correct, but pydantic typings/mypy don't think so. # This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
"RunnableMapInput", "RunnableParallelInput",
**{ **{
k: (v.annotation, v.default) k: (v.annotation, v.default)
for step in self.steps.values() for step in self.steps.values()
@ -1553,7 +1564,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
def output_schema(self) -> Type[BaseModel]: def output_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so. # This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
"RunnableMapOutput", "RunnableParallelOutput",
**{k: (v.OutputType, None) for k, v in self.steps.items()}, **{k: (v.OutputType, None) for k, v in self.steps.items()},
) )
@ -1797,6 +1808,10 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
yield chunk yield chunk
# We support both names
RunnableMap = RunnableParallel
class RunnableGenerator(Runnable[Input, Output]): class RunnableGenerator(Runnable[Input, Output]):
""" """
A runnable that runs a generator function. A runnable that runs a generator function.
@ -2435,10 +2450,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
elif callable(thing): elif callable(thing):
return RunnableLambda(cast(Callable[[Input], Output], thing)) return RunnableLambda(cast(Callable[[Input], Output], thing))
elif isinstance(thing, dict): elif isinstance(thing, dict):
runnables: Mapping[str, Runnable[Any, Any]] = { return cast(Runnable[Input, Output], RunnableParallel(thing))
key: coerce_to_runnable(r) for key, r in thing.items()
}
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
else: else:
raise TypeError( raise TypeError(
f"Expected a Runnable, callable or dict." f"Expected a Runnable, callable or dict."

View File

@ -21,7 +21,7 @@ from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import ( from langchain.schema.runnable.base import (
Input, Input,
Runnable, Runnable,
RunnableMap, RunnableParallel,
RunnableSerializable, RunnableSerializable,
) )
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
@ -83,7 +83,7 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]):
A runnable that merges the Dict input with the output produced by the A runnable that merges the Dict input with the output produced by the
mapping argument. mapping argument.
""" """
return RunnableAssign(RunnableMap(kwargs)) return RunnableAssign(RunnableParallel(kwargs))
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(identity, input, config) return self._call_with_config(identity, input, config)
@ -119,9 +119,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
A runnable that assigns key-value pairs to Dict[str, Any] inputs. A runnable that assigns key-value pairs to Dict[str, Any] inputs.
""" """
mapper: RunnableMap[Dict[str, Any]] mapper: RunnableParallel[Dict[str, Any]]
def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None: def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
super().__init__(mapper=mapper, **kwargs) super().__init__(mapper=mapper, **kwargs)
@classmethod @classmethod

View File

@ -5,7 +5,7 @@ from langchain.llms.opaqueprompts import OpaquePrompts
from langchain.memory import ConversationBufferWindowMemory from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap from langchain.schema.runnable import RunnableParallel
prompt_template = """ prompt_template = """
As an AI assistant, you will answer questions according to given context. As an AI assistant, you will answer questions according to given context.
@ -64,14 +64,12 @@ def test_opaqueprompts_functions() -> None:
llm = OpenAI() llm = OpenAI()
pg_chain = ( pg_chain = (
op.sanitize op.sanitize
| RunnableMap( | RunnableParallel(
{ secure_context=lambda x: x["secure_context"], # type: ignore
"response": (lambda x: x["sanitized_input"]) # type: ignore response=(lambda x: x["sanitized_input"]) # type: ignore
| prompt | prompt
| llm | llm
| StrOutputParser(), | StrOutputParser(),
"secure_context": lambda x: x["secure_context"],
}
) )
| (lambda x: op.desanitize(x["response"], x["secure_context"])) | (lambda x: op.desanitize(x["response"], x["secure_context"]))
) )

View File

@ -629,7 +629,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {
@ -643,7 +643,7 @@
"base", "base",
"RunnableLambda" "RunnableLambda"
], ],
"repr": "RunnableLambda(...)" "repr": "RunnableLambda(lambda x: x['key'])"
}, },
"input": { "input": {
"lc": 1, "lc": 1,
@ -652,7 +652,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {
@ -666,7 +666,7 @@
"base", "base",
"RunnableLambda" "RunnableLambda"
], ],
"repr": "RunnableLambda(...)" "repr": "RunnableLambda(lambda x: x['question'])"
} }
} }
} }
@ -709,7 +709,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {
@ -1438,7 +1438,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {
@ -1461,7 +1461,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {
@ -3455,7 +3455,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {
@ -3769,7 +3769,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {
@ -3910,7 +3910,7 @@
"langchain", "langchain",
"schema", "schema",
"runnable", "runnable",
"RunnableMap" "RunnableParallel"
], ],
"kwargs": { "kwargs": {
"steps": { "steps": {

View File

@ -53,7 +53,7 @@ from langchain.schema.runnable import (
RunnableBranch, RunnableBranch,
RunnableConfig, RunnableConfig,
RunnableLambda, RunnableLambda,
RunnableMap, RunnableParallel,
RunnablePassthrough, RunnablePassthrough,
RunnableSequence, RunnableSequence,
RunnableWithFallbacks, RunnableWithFallbacks,
@ -491,7 +491,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"properties": {"name": {"title": "Name", "type": "string"}}, "properties": {"name": {"title": "Name", "type": "string"}},
} }
assert seq_w_map.output_schema.schema() == { assert seq_w_map.output_schema.schema() == {
"title": "RunnableMapOutput", "title": "RunnableParallelOutput",
"type": "object", "type": "object",
"properties": { "properties": {
"original": {"title": "Original", "type": "string"}, "original": {"title": "Original", "type": "string"},
@ -593,7 +593,7 @@ def test_schema_complex_seq() -> None:
) )
assert chain2.input_schema.schema() == { assert chain2.input_schema.schema() == {
"title": "RunnableMapInput", "title": "RunnableParallelInput",
"type": "object", "type": "object",
"properties": { "properties": {
"person": {"title": "Person", "type": "string"}, "person": {"title": "Person", "type": "string"},
@ -1656,12 +1656,12 @@ async def test_stream_log_retriever() -> None:
RunLogPatch( RunLogPatch(
{ {
"op": "add", "op": "add",
"path": "/logs/RunnableMap", "path": "/logs/RunnableParallel",
"value": { "value": {
"end_time": None, "end_time": None,
"final_output": None, "final_output": None,
"metadata": {}, "metadata": {},
"name": "RunnableMap", "name": "RunnableParallel",
"start_time": "2023-01-01T00:00:00.000", "start_time": "2023-01-01T00:00:00.000",
"streamed_output_str": [], "streamed_output_str": [],
"tags": ["seq:step:1"], "tags": ["seq:step:1"],
@ -1733,7 +1733,7 @@ async def test_stream_log_retriever() -> None:
RunLogPatch( RunLogPatch(
{ {
"op": "add", "op": "add",
"path": "/logs/RunnableMap/final_output", "path": "/logs/RunnableParallel/final_output",
"value": { "value": {
"documents": [ "documents": [
Document(page_content="foo"), Document(page_content="foo"),
@ -1744,7 +1744,7 @@ async def test_stream_log_retriever() -> None:
}, },
{ {
"op": "add", "op": "add",
"path": "/logs/RunnableMap/end_time", "path": "/logs/RunnableParallel/end_time",
"value": "2023-01-01T00:00:00.000", "value": "2023-01-01T00:00:00.000",
}, },
), ),
@ -1792,8 +1792,8 @@ async def test_stream_log_retriever() -> None:
"FakeListLLM:2", "FakeListLLM:2",
"Retriever", "Retriever",
"RunnableLambda", "RunnableLambda",
"RunnableMap", "RunnableParallel",
"RunnableMap:2", "RunnableParallel:2",
] ]
@ -1977,7 +1977,7 @@ Question:
assert repr(chain) == snapshot assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence) assert isinstance(chain, RunnableSequence)
assert isinstance(chain.first, RunnableMap) assert isinstance(chain.first, RunnableParallel)
assert chain.middle == [prompt, chat] assert chain.middle == [prompt, chat]
assert chain.last == parser assert chain.last == parser
assert dumps(chain, pretty=True) == snapshot assert dumps(chain, pretty=True) == snapshot
@ -2013,7 +2013,7 @@ What is your name?"""
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 4 assert len(parent_run.child_runs) == 4
map_run = parent_run.child_runs[0] map_run = parent_run.child_runs[0]
assert map_run.name == "RunnableMap" assert map_run.name == "RunnableParallel"
assert len(map_run.child_runs) == 3 assert len(map_run.child_runs) == 3
@ -2043,7 +2043,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
assert isinstance(chain, RunnableSequence) assert isinstance(chain, RunnableSequence)
assert chain.first == prompt assert chain.first == prompt
assert chain.middle == [RunnableLambda(passthrough)] assert chain.middle == [RunnableLambda(passthrough)]
assert isinstance(chain.last, RunnableMap) assert isinstance(chain.last, RunnableParallel)
assert dumps(chain, pretty=True) == snapshot assert dumps(chain, pretty=True) == snapshot
# Test invoke # Test invoke
@ -2074,7 +2074,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 3 assert len(parent_run.child_runs) == 3
map_run = parent_run.child_runs[2] map_run = parent_run.child_runs[2]
assert map_run.name == "RunnableMap" assert map_run.name == "RunnableParallel"
assert len(map_run.child_runs) == 2 assert len(map_run.child_runs) == 2
@ -2142,11 +2142,9 @@ async def test_higher_order_lambda_runnable(
english_chain = ChatPromptTemplate.from_template( english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}" "You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"]) ) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableMap( input_map: Runnable = RunnableParallel(
{ # type: ignore[arg-type] key=lambda x: x["key"],
"key": lambda x: x["key"], input={"question": lambda x: x["question"]},
"input": {"question": lambda x: x["question"]},
}
) )
def router(input: Dict[str, Any]) -> Runnable: def router(input: Dict[str, Any]) -> Runnable:
@ -2158,6 +2156,7 @@ async def test_higher_order_lambda_runnable(
raise ValueError(f"Unknown key: {input['key']}") raise ValueError(f"Unknown key: {input['key']}")
chain: Runnable = input_map | router chain: Runnable = input_map | router
if sys.version_info >= (3, 9):
assert dumps(chain, pretty=True) == snapshot assert dumps(chain, pretty=True) == snapshot
result = chain.invoke({"key": "math", "question": "2 + 2"}) result = chain.invoke({"key": "math", "question": "2 + 2"})
@ -2256,7 +2255,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert isinstance(chain, RunnableSequence) assert isinstance(chain, RunnableSequence)
assert chain.first == prompt assert chain.first == prompt
assert chain.middle == [RunnableLambda(passthrough)] assert chain.middle == [RunnableLambda(passthrough)]
assert isinstance(chain.last, RunnableMap) assert isinstance(chain.last, RunnableParallel)
assert dumps(chain, pretty=True) == snapshot assert dumps(chain, pretty=True) == snapshot
# Test invoke # Test invoke
@ -2293,7 +2292,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 3 assert len(parent_run.child_runs) == 3
map_run = parent_run.child_runs[2] map_run = parent_run.child_runs[2]
assert map_run.name == "RunnableMap" assert map_run.name == "RunnableParallel"
assert len(map_run.child_runs) == 3 assert len(map_run.child_runs) == 3
@ -2460,12 +2459,12 @@ async def test_map_astream() -> None:
assert final_state.state["logs"]["ChatPromptTemplate"][ assert final_state.state["logs"]["ChatPromptTemplate"][
"final_output" "final_output"
] == prompt.invoke({"question": "What is your name?"}) ] == prompt.invoke({"question": "What is your name?"})
assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap" assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
assert sorted(final_state.state["logs"]) == [ assert sorted(final_state.state["logs"]) == [
"ChatPromptTemplate", "ChatPromptTemplate",
"FakeListChatModel", "FakeListChatModel",
"FakeStreamingListLLM", "FakeStreamingListLLM",
"RunnableMap", "RunnableParallel",
"RunnablePassthrough", "RunnablePassthrough",
] ]
@ -2505,11 +2504,11 @@ async def test_map_astream() -> None:
assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == ( assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
prompt.invoke({"question": "What is your name?"}) prompt.invoke({"question": "What is your name?"})
) )
assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap" assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
assert sorted(final_state.state["logs"]) == [ assert sorted(final_state.state["logs"]) == [
"ChatPromptTemplate", "ChatPromptTemplate",
"FakeStreamingListLLM", "FakeStreamingListLLM",
"RunnableMap", "RunnableParallel",
"RunnablePassthrough", "RunnablePassthrough",
] ]
@ -2910,7 +2909,7 @@ def llm_chain_with_fallbacks() -> RunnableSequence:
pass_llm = FakeListLLM(responses=["bar"]) pass_llm = FakeListLLM(responses=["bar"])
prompt = PromptTemplate.from_template("what did baz say to {buz}") prompt = PromptTemplate.from_template("what did baz say to {buz}")
return RunnableMap({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks( return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
[prompt | pass_llm] [prompt | pass_llm]
) )