commit 5e3b968078 (parent 913a156cff)

    router runnable (#8496)

    Co-authored-by: Nuno Campos <nuno@boringbits.io>
@@ -22,10 +22,19 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": 1,
     "id": "466b65b3",
     "metadata": {},
-    "outputs": [],
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
+       "  warnings.warn(\n"
+      ]
+     }
+    ],
     "source": [
      "from langchain.prompts import ChatPromptTemplate\n",
      "from langchain.chat_models import ChatOpenAI"
@@ -33,7 +42,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": 2,
     "id": "3c634ef0",
     "metadata": {},
     "outputs": [],
@@ -583,6 +592,98 @@
     "chain2.invoke({})"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "d094d637",
+   "metadata": {},
+   "source": [
+    "## Router\n",
+    "\n",
+    "You can also use the router runnable to conditionally route inputs to different runnables."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "252625fd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chains import create_tagging_chain_pydantic\n",
+    "from pydantic import BaseModel, Field\n",
+    "\n",
+    "class PromptToUse(BaseModel):\n",
+    "    \"\"\"Used to determine which prompt to use to answer the user's input.\"\"\"\n",
+    "    \n",
+    "    name: str = Field(description=\"Should be one of `math` or `english`\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "57886e84",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tagger = create_tagging_chain_pydantic(PromptToUse, ChatOpenAI(temperature=0))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "a303b089",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain1 = ChatPromptTemplate.from_template(\"You are a math genius. Answer the question: {question}\") | ChatOpenAI()\n",
+    "chain2 = ChatPromptTemplate.from_template(\"You are an english major. Answer the question: {question}\") | ChatOpenAI()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "7aa9ea06",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.schema.runnable import RouterRunnable\n",
+    "router = RouterRunnable({\"math\": chain1, \"english\": chain2})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "6a3d3f5d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain = {\n",
+    "    \"key\": {\"input\": lambda x: x[\"question\"]} | tagger | (lambda x: x['text'].name),\n",
+    "    \"input\": {\"question\": lambda x: x[\"question\"]}\n",
+    "} | router"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "8aeda930",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='Thank you for the compliment! The sum of 2 + 2 is equal to 4.', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chain.invoke({\"question\": \"whats 2 + 2\"})"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "29781123",
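The new notebook cells wire the router together in three steps: a tagging chain classifies the question as `math` or `english`, a lambda pulls the chosen `name` out of the tagger's output, and the resulting `{"key": ..., "input": ...}` dict is piped into the RouterRunnable. The same shape can be exercised without an OpenAI key; the sketch below is illustrative only, swapping the tagging chain for a plain keyword check (`pick_key` and the `RunnableLambda` destinations are stand-ins, not part of this commit):

# Minimal sketch of the routing pattern from the notebook cells above.
# Assumes the langchain.schema.runnable exports used by this commit;
# pick_key is a hypothetical stand-in for the create_tagging_chain_pydantic
# classifier, so no model call or API key is required.
from langchain.schema.runnable import RouterRunnable, RunnableLambda

# Two destination runnables; in the notebook these are prompt | ChatOpenAI().
math_chain = RunnableLambda(lambda q: f"math answer to: {q}")
english_chain = RunnableLambda(lambda q: f"english answer to: {q}")

router = RouterRunnable({"math": math_chain, "english": english_chain})

def pick_key(x: dict) -> str:
    # Stand-in classifier: route to "math" when the question contains a digit.
    return "math" if any(c.isdigit() for c in x["question"]) else "english"

# The mapping feeding the router must produce {"key": ..., "input": ...},
# matching the RouterInput TypedDict introduced by this commit.
chain = {
    "key": pick_key,
    "input": lambda x: x["question"],
} | router

print(chain.invoke({"question": "whats 2 + 2"}))   # handled by math_chain
print(chain.invoke({"question": "define irony"}))  # handled by english_chain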
@@ -108,6 +108,10 @@ class Runnable(Generic[Input, Output], ABC):
     ) -> List[Output]:
         configs = self._get_config_list(config, len(inputs))
 
+        # If there's only one input, don't bother with the executor
+        if len(inputs) == 1:
+            return [self.invoke(inputs[0], configs[0])]
+
         with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
             return list(executor.map(self.invoke, inputs, configs))
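The hunk above adds a fast path to Runnable.batch: with a single input it now performs a direct invoke instead of paying for thread-pool setup and teardown. A standalone sketch of the same idea (illustrative only, not langchain's actual class):

# Standalone sketch of the batch() fast path shown in the hunk above.
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional

def batch(fn: Callable[[int], int], inputs: List[int],
          max_concurrency: Optional[int] = None) -> List[int]:
    # If there's only one input, don't bother with the executor.
    if len(inputs) == 1:
        return [fn(inputs[0])]
    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        return list(executor.map(fn, inputs))

assert batch(lambda x: x * 2, [21]) == [42]             # fast path, no pool
assert batch(lambda x: x * 2, [1, 2, 3]) == [2, 4, 6]   # pooled path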
@@ -759,6 +763,140 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
             yield item
 
 
+class RouterInput(TypedDict):
+    key: str
+    input: Any
+
+
+class RouterRunnable(
+    Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
+):
+    runnables: Mapping[str, Runnable[Input, Output]]
+
+    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
+        super().__init__(runnables=runnables)
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    def __or__(
+        self,
+        other: Union[
+            Runnable[Any, Other],
+            Callable[[Any], Other],
+            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
+            Mapping[str, Any],
+        ],
+    ) -> RunnableSequence[RouterInput, Other]:
+        return RunnableSequence(first=self, last=_coerce_to_runnable(other))
+
+    def __ror__(
+        self,
+        other: Union[
+            Runnable[Other, Any],
+            Callable[[Any], Other],
+            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
+            Mapping[str, Any],
+        ],
+    ) -> RunnableSequence[Other, Output]:
+        return RunnableSequence(first=_coerce_to_runnable(other), last=self)
+
+    def invoke(
+        self, input: RouterInput, config: Optional[RunnableConfig] = None
+    ) -> Output:
+        key = input["key"]
+        actual_input = input["input"]
+        if key not in self.runnables:
+            raise ValueError(f"No runnable associated with key '{key}'")
+
+        runnable = self.runnables[key]
+        return runnable.invoke(actual_input, config)
+
+    async def ainvoke(
+        self, input: RouterInput, config: Optional[RunnableConfig] = None
+    ) -> Output:
+        key = input["key"]
+        actual_input = input["input"]
+        if key not in self.runnables:
+            raise ValueError(f"No runnable associated with key '{key}'")
+
+        runnable = self.runnables[key]
+        return await runnable.ainvoke(actual_input, config)
+
+    def batch(
+        self,
+        inputs: List[RouterInput],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        max_concurrency: Optional[int] = None,
+    ) -> List[Output]:
+        keys = [input["key"] for input in inputs]
+        actual_inputs = [input["input"] for input in inputs]
+        if any(key not in self.runnables for key in keys):
+            raise ValueError("One or more keys do not have a corresponding runnable")
+
+        runnables = [self.runnables[key] for key in keys]
+        configs = self._get_config_list(config, len(inputs))
+        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
+            return list(
+                executor.map(
+                    lambda runnable, input, config: runnable.invoke(input, config),
+                    runnables,
+                    actual_inputs,
+                    configs,
+                )
+            )
+
+    async def abatch(
+        self,
+        inputs: List[RouterInput],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        max_concurrency: Optional[int] = None,
+    ) -> List[Output]:
+        keys = [input["key"] for input in inputs]
+        actual_inputs = [input["input"] for input in inputs]
+        if any(key not in self.runnables for key in keys):
+            raise ValueError("One or more keys do not have a corresponding runnable")
+
+        runnables = [self.runnables[key] for key in keys]
+        configs = self._get_config_list(config, len(inputs))
+        return await _gather_with_concurrency(
+            max_concurrency,
+            *(
+                runnable.ainvoke(input, config)
+                for runnable, input, config in zip(runnables, actual_inputs, configs)
+            ),
+        )
+
+    def stream(
+        self, input: RouterInput, config: Optional[RunnableConfig] = None
+    ) -> Iterator[Output]:
+        key = input["key"]
+        actual_input = input["input"]
+        if key not in self.runnables:
+            raise ValueError(f"No runnable associated with key '{key}'")
+
+        runnable = self.runnables[key]
+        yield from runnable.stream(actual_input, config)
+
+    async def astream(
+        self, input: RouterInput, config: Optional[RunnableConfig] = None
+    ) -> AsyncIterator[Output]:
+        key = input["key"]
+        actual_input = input["input"]
+        if key not in self.runnables:
+            raise ValueError(f"No runnable associated with key '{key}'")
+
+        runnable = self.runnables[key]
+        async for output in runnable.astream(actual_input, config):
+            yield output
+
+
 def _patch_config(
     config: RunnableConfig, callback_manager: BaseCallbackManager
 ) -> RunnableConfig:
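RouterRunnable dispatches purely on input["key"]: every entry point (invoke, ainvoke, batch, abatch, stream, astream) looks the key up in the runnables mapping and delegates to the matching runnable with input["input"], raising ValueError for unknown keys. A small usage sketch against the class as added above (the `RunnableLambda` destinations are illustrative):

# Usage sketch for the RouterRunnable added in the hunk above.
from langchain.schema.runnable import RouterRunnable, RunnableLambda

router = RouterRunnable({
    "upper": RunnableLambda(lambda s: s.upper()),
    "lower": RunnableLambda(lambda s: s.lower()),
})

# invoke() takes a RouterInput-shaped dict: {"key": ..., "input": ...}.
assert router.invoke({"key": "upper", "input": "Hello"}) == "HELLO"

# batch() validates all keys up front, then fans out over a thread pool.
assert router.batch([
    {"key": "upper", "input": "a"},
    {"key": "lower", "input": "B"},
]) == ["A", "b"]

# Unknown keys fail fast:
# router.invoke({"key": "latin", "input": "x"})
#   -> ValueError: No runnable associated with key 'latin'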
File diff suppressed because one or more lines are too long
@@ -23,6 +23,7 @@ from langchain.schema.document import Document
 from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
 from langchain.schema.retriever import BaseRetriever
 from langchain.schema.runnable import (
+    RouterRunnable,
     Runnable,
     RunnableConfig,
     RunnableLambda,
@@ -572,6 +573,54 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
     assert len(map_run.child_runs) == 2
 
 
+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_router_runnable(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    chain1 = ChatPromptTemplate.from_template(
+        "You are a math genius. Answer the question: {question}"
+    ) | FakeListLLM(responses=["4"])
+    chain2 = ChatPromptTemplate.from_template(
+        "You are an english major. Answer the question: {question}"
+    ) | FakeListLLM(responses=["2"])
+    router = RouterRunnable({"math": chain1, "english": chain2})
+    chain: Runnable = {
+        "key": lambda x: x["key"],
+        "input": {"question": lambda x: x["question"]},
+    } | router
+    assert dumps(chain, pretty=True) == snapshot
+
+    result = chain.invoke({"key": "math", "question": "2 + 2"})
+    assert result == "4"
+
+    result2 = chain.batch(
+        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
+    )
+    assert result2 == ["4", "2"]
+
+    result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
+    assert result == "4"
+
+    result2 = await chain.abatch(
+        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
+    )
+    assert result2 == ["4", "2"]
+
+    # Test invoke
+    router_spy = mocker.spy(router.__class__, "invoke")
+    tracer = FakeTracer()
+    assert (
+        chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
+        == "4"
+    )
+    assert router_spy.call_args.args[1] == {
+        "key": "math",
+        "input": {"question": "2 + 2"},
+    }
+    assert tracer.runs == snapshot
+
+
 @freeze_time("2023-01-01")
 def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
     passthrough = mocker.Mock(side_effect=lambda x: x)
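One detail worth noting in the test: mocker.spy wraps the unbound invoke on the router's class, so call_args.args[0] is the router instance and args[1] is the RouterInput dict assembled by the upstream map. A minimal pytest-mock sketch of that pattern, independent of langchain (the `Greeter` class is hypothetical):

# Minimal pytest-mock sketch of the spy assertion used above.
class Greeter:
    def greet(self, name: str) -> str:
        return f"hi {name}"

def test_spy_sees_instance_then_args(mocker):
    greeter = Greeter()
    spy = mocker.spy(Greeter, "greet")  # wrap the unbound method on the class
    assert greeter.greet("nuno") == "hi nuno"
    # args[0] is self (the instance); args[1] is the first real argument.
    assert spy.call_args.args[0] is greeter
    assert spy.call_args.args[1] == "nuno"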