router runnable (#8496)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
Harrison Chase
2023-07-31 11:07:10 -07:00
committed by GitHub
parent 913a156cff
commit 5e3b968078
4 changed files with 512 additions and 3 deletions

View File

@@ -22,10 +22,19 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"id": "466b65b3",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.chat_models import ChatOpenAI"
@@ -33,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"id": "3c634ef0",
"metadata": {},
"outputs": [],
@@ -583,6 +592,98 @@
"chain2.invoke({})"
]
},
{
"cell_type": "markdown",
"id": "d094d637",
"metadata": {},
"source": [
"## Router\n",
"\n",
"You can also use the router runnable to conditionally route inputs to different runnables."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "252625fd",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import create_tagging_chain_pydantic\n",
"from pydantic import BaseModel, Field\n",
"\n",
"class PromptToUse(BaseModel):\n",
" \"\"\"Used to determine which prompt to use to answer the user's input.\"\"\"\n",
" \n",
" name: str = Field(description=\"Should be one of `math` or `english`\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "57886e84",
"metadata": {},
"outputs": [],
"source": [
"tagger = create_tagging_chain_pydantic(PromptToUse, ChatOpenAI(temperature=0))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a303b089",
"metadata": {},
"outputs": [],
"source": [
"chain1 = ChatPromptTemplate.from_template(\"You are a math genius. Answer the question: {question}\") | ChatOpenAI()\n",
"chain2 = ChatPromptTemplate.from_template(\"You are an english major. Answer the question: {question}\") | ChatOpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7aa9ea06",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RouterRunnable\n",
"router = RouterRunnable({\"math\": chain1, \"english\": chain2})"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6a3d3f5d",
"metadata": {},
"outputs": [],
"source": [
"chain = {\n",
" \"key\": {\"input\": lambda x: x[\"question\"]} | tagger | (lambda x: x['text'].name),\n",
" \"input\": {\"question\": lambda x: x[\"question\"]}\n",
"} | router"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8aeda930",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Thank you for the compliment! The sum of 2 + 2 is equal to 4.', additional_kwargs={}, example=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"question\": \"whats 2 + 2\"})"
]
},
{
"cell_type": "markdown",
"id": "29781123",