mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 03:26:17 +00:00
Fetch runnable config from context var inside runnable lambda and runnable generator (#15334)
- easier to write custom logic/loops with automatic tracing - if you don't want to streaming support write a regular function and pass to RunnableLambda - if you do want streaming write a generator and pass it to RunnableGenerator ```py import json from typing import AsyncIterator from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage from langchain_core.agents import AgentAction, AgentFinish from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable, RunnableGenerator, RunnablePassthrough from langchain_core.tools import BaseTool from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.chat_models import ChatOpenAI from langchain.tools.render import format_tool_to_openai_function def _get_tavily(): from langchain.tools.tavily_search import TavilySearchResults from langchain.utilities.tavily_search import TavilySearchAPIWrapper tavily_search = TavilySearchAPIWrapper() return TavilySearchResults(api_wrapper=tavily_search) async def _agent_executor_generator( input: AsyncIterator[list[BaseMessage]], *, max_iterations: int = 10, tools: dict[str, BaseTool], agent: Runnable[list[BaseMessage], BaseMessage], parser: Runnable[BaseMessage, AgentAction | AgentFinish], ) -> AsyncIterator[BaseMessage]: messages = [m async for mm in input for m in mm] for _ in range(max_iterations): next_message = await agent.ainvoke(messages) yield next_message messages.append(next_message) parsed = await parser.ainvoke(next_message) if isinstance(parsed, AgentAction): result = await tools[parsed.tool].ainvoke(parsed.tool_input) next_message = FunctionMessage(name=parsed.tool, content=json.dumps(result)) yield next_message messages.append(next_message) elif isinstance(parsed, AgentFinish): return def get_agent_executor(tools: list[BaseTool], system_message: str): llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0, streaming=True) prompt = ChatPromptTemplate.from_messages( [ ("system", system_message), MessagesPlaceholder(variable_name="messages"), ] ) llm_with_tools = llm.bind( functions=[format_tool_to_openai_function(t) for t in tools] ) agent = {"messages": RunnablePassthrough()} | prompt | llm_with_tools parser = OpenAIFunctionsAgentOutputParser() executor = RunnableGenerator(_agent_executor_generator) return executor.bind( tools={tool.name for tool in tools}, agent=agent, parser=parser ) agent = get_agent_executor([_get_tavily()], "You are a very nice agent!") async def main(): async for message in agent.astream( [HumanMessage(content="whats the weather in sf tomorrow?")] ): print(message) if __name__ == "__main__": import asyncio asyncio.run(main()) ``` results in this trace https://smith.langchain.com/public/fa17f05d-9724-4d08-8fa1-750f8fcd051b/r
This commit is contained in:
@@ -12,7 +12,7 @@ from langchain_core.runnables.utils import aadd, add
|
||||
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
||||
|
||||
|
||||
class TestCase(NamedTuple):
|
||||
class _TestCase(NamedTuple):
|
||||
input: Any
|
||||
output: Any
|
||||
|
||||
@@ -102,22 +102,22 @@ test_cases = [
|
||||
(
|
||||
Context.setter("foo") | Context.getter("foo"),
|
||||
(
|
||||
TestCase("foo", "foo"),
|
||||
TestCase("bar", "bar"),
|
||||
_TestCase("foo", "foo"),
|
||||
_TestCase("bar", "bar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
Context.setter("input") | {"bar": Context.getter("input")},
|
||||
(
|
||||
TestCase("foo", {"bar": "foo"}),
|
||||
TestCase("bar", {"bar": "bar"}),
|
||||
_TestCase("foo", {"bar": "foo"}),
|
||||
_TestCase("bar", {"bar": "bar"}),
|
||||
),
|
||||
),
|
||||
(
|
||||
{"bar": Context.setter("input")} | Context.getter("input"),
|
||||
(
|
||||
TestCase("foo", "foo"),
|
||||
TestCase("bar", "bar"),
|
||||
_TestCase("foo", "foo"),
|
||||
_TestCase("bar", "bar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
@@ -132,11 +132,11 @@ test_cases = [
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||
),
|
||||
@@ -155,7 +155,7 @@ test_cases = [
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{
|
||||
"response": "hello",
|
||||
@@ -163,7 +163,7 @@ test_cases = [
|
||||
"prompt_str": "foo bar",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{
|
||||
"response": "hello",
|
||||
@@ -185,11 +185,11 @@ test_cases = [
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt_str": "foo bar"},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt_str": "bar foo"},
|
||||
),
|
||||
@@ -207,11 +207,11 @@ test_cases = [
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt_str": "foo bar"},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt_str": "bar foo"},
|
||||
),
|
||||
@@ -229,11 +229,11 @@ test_cases = [
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||
),
|
||||
@@ -242,7 +242,7 @@ test_cases = [
|
||||
(
|
||||
seq_naive_rag,
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
@@ -254,7 +254,7 @@ test_cases = [
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
@@ -271,7 +271,7 @@ test_cases = [
|
||||
(
|
||||
seq_naive_rag_alt,
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
@@ -283,7 +283,7 @@ test_cases = [
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
@@ -300,7 +300,7 @@ test_cases = [
|
||||
(
|
||||
seq_naive_rag_scoped,
|
||||
(
|
||||
TestCase(
|
||||
_TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
@@ -312,7 +312,7 @@ test_cases = [
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
_TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
@@ -331,7 +331,7 @@ test_cases = [
|
||||
|
||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
||||
async def test_context_runnables(
|
||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase]
|
||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: List[_TestCase]
|
||||
) -> None:
|
||||
runnable = runnable if isinstance(runnable, Runnable) else runnable()
|
||||
assert runnable.invoke(cases[0].input) == cases[0].output
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from langchain_core.runnables import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"chain",
|
||||
"AddableDict",
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
|
@@ -68,6 +68,7 @@ from langchain_core.runnables import (
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
add,
|
||||
chain,
|
||||
)
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from langchain_core.tracers import (
|
||||
@@ -4388,9 +4389,9 @@ async def test_runnable_gen() -> None:
|
||||
|
||||
runnable = RunnableGenerator(gen)
|
||||
|
||||
assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"}
|
||||
assert runnable.input_schema.schema() == {"title": "gen_input"}
|
||||
assert runnable.output_schema.schema() == {
|
||||
"title": "RunnableGeneratorOutput",
|
||||
"title": "gen_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
@@ -4410,6 +4411,315 @@ async def test_runnable_gen() -> None:
|
||||
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||
|
||||
|
||||
async def test_runnable_gen_context_config() -> None:
|
||||
"""Test that a generator can call other runnables with config
|
||||
propagated from the context."""
|
||||
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
def gen(input: Iterator[Any]) -> Iterator[int]:
|
||||
yield fake.invoke("a")
|
||||
yield fake.invoke("aa")
|
||||
yield fake.invoke("aaa")
|
||||
|
||||
runnable = RunnableGenerator(gen)
|
||||
|
||||
assert runnable.input_schema.schema() == {"title": "gen_input"}
|
||||
assert runnable.output_schema.schema() == {
|
||||
"title": "gen_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
tracer.runs.clear()
|
||||
|
||||
assert list(runnable.stream(None)) == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
assert len(tracer.runs[1].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
# Python 3.10 and below don't support running async tasks in a specific context
|
||||
return
|
||||
|
||||
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
yield await fake.ainvoke("a")
|
||||
yield await fake.ainvoke("aa")
|
||||
yield await fake.ainvoke("aaa")
|
||||
|
||||
arunnable = RunnableGenerator(agen)
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
tracer.runs.clear()
|
||||
|
||||
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
assert len(tracer.runs[1].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||
|
||||
|
||||
async def test_runnable_iter_context_config() -> None:
|
||||
"""Test that a generator can call other runnables with config
|
||||
propagated from the context."""
|
||||
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
@chain
|
||||
def gen(input: str) -> Iterator[int]:
|
||||
yield fake.invoke(input)
|
||||
yield fake.invoke(input * 2)
|
||||
yield fake.invoke(input * 3)
|
||||
|
||||
assert gen.input_schema.schema() == {
|
||||
"title": "gen_input",
|
||||
"type": "string",
|
||||
}
|
||||
assert gen.output_schema.schema() == {
|
||||
"title": "gen_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert gen.invoke("a", {"callbacks": [tracer]}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
tracer.runs.clear()
|
||||
|
||||
assert list(gen.stream("a")) == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert list(gen.stream("a", {"callbacks": [tracer]})) == [1, 2, 3]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert gen.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
assert len(tracer.runs[1].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
# Python 3.10 and below don't support running async tasks in a specific context
|
||||
return
|
||||
|
||||
@chain
|
||||
async def agen(input: str) -> AsyncIterator[int]:
|
||||
yield await fake.ainvoke(input)
|
||||
yield await fake.ainvoke(input * 2)
|
||||
yield await fake.ainvoke(input * 3)
|
||||
|
||||
assert agen.input_schema.schema() == {
|
||||
"title": "agen_input",
|
||||
"type": "string",
|
||||
}
|
||||
assert agen.output_schema.schema() == {
|
||||
"title": "agen_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await agen.ainvoke("a", {"callbacks": [tracer]}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
tracer.runs.clear()
|
||||
|
||||
assert [p async for p in agen.astream("a")] == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert [p async for p in agen.astream("a", {"callbacks": [tracer]})] == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await agen.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
assert len(tracer.runs[1].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||
|
||||
|
||||
async def test_runnable_lambda_context_config() -> None:
|
||||
"""Test that a function can call other runnables with config
|
||||
propagated from the context."""
|
||||
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
@chain
|
||||
def fun(input: str) -> int:
|
||||
output = fake.invoke(input)
|
||||
output += fake.invoke(input * 2)
|
||||
output += fake.invoke(input * 3)
|
||||
return output
|
||||
|
||||
assert fun.input_schema.schema() == {"title": "fun_input", "type": "string"}
|
||||
assert fun.output_schema.schema() == {
|
||||
"title": "fun_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert fun.invoke("a", {"callbacks": [tracer]}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
tracer.runs.clear()
|
||||
|
||||
assert list(fun.stream("a")) == [6]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert list(fun.stream("a", {"callbacks": [tracer]})) == [6]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert fun.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
assert len(tracer.runs[1].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
# Python 3.10 and below don't support running async tasks in a specific context
|
||||
return
|
||||
|
||||
@chain
|
||||
async def afun(input: str) -> int:
|
||||
output = await fake.ainvoke(input)
|
||||
output += await fake.ainvoke(input * 2)
|
||||
output += await fake.ainvoke(input * 3)
|
||||
return output
|
||||
|
||||
assert afun.input_schema.schema() == {"title": "afun_input", "type": "string"}
|
||||
assert afun.output_schema.schema() == {
|
||||
"title": "afun_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await afun.ainvoke("a", {"callbacks": [tracer]}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
tracer.runs.clear()
|
||||
|
||||
assert [p async for p in afun.astream("a")] == [6]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert [p async for p in afun.astream("a", {"callbacks": [tracer]})] == [6]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await afun.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
assert len(tracer.runs[1].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]
|
||||
|
||||
|
||||
async def test_runnable_gen_transform() -> None:
|
||||
"""Test that a generator can be used as a runnable."""
|
||||
|
||||
@@ -4434,19 +4744,19 @@ async def test_runnable_gen_transform() -> None:
|
||||
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
|
||||
|
||||
assert chain.input_schema.schema() == {
|
||||
"title": "RunnableGeneratorInput",
|
||||
"title": "gen_indexes_input",
|
||||
"type": "integer",
|
||||
}
|
||||
assert chain.output_schema.schema() == {
|
||||
"title": "RunnableGeneratorOutput",
|
||||
"title": "plus_one_output",
|
||||
"type": "integer",
|
||||
}
|
||||
assert achain.input_schema.schema() == {
|
||||
"title": "RunnableGeneratorInput",
|
||||
"title": "gen_indexes_input",
|
||||
"type": "integer",
|
||||
}
|
||||
assert achain.output_schema.schema() == {
|
||||
"title": "RunnableGeneratorOutput",
|
||||
"title": "aplus_one_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user