Compare commits

...

1 Commits

Author SHA1 Message Date
Eugene Yurtsev
53cb703f71 x 2024-05-14 11:44:11 -04:00

View File

@@ -362,6 +362,10 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
]
async def test_foo() -> None:
""""""
async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
@@ -397,6 +401,116 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
]
async def test_foo() -> None:
"""Test the output of a model."""
infinite_cycle = cycle([AIMessage(content="hello world!")])
# When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
model = (
GenericFakeChatModel(messages=infinite_cycle)
.with_config(
{
"metadata": {"a": "b"},
"tags": ["my_model"],
"run_name": "my_model",
}
)
.bind(stop="<stop_token>")
)
@RunnableLambda
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11):
return model.invoke(input)
else:
return model.invoke(input, config)
events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chain_start",
"metadata": {},
"name": "i_dont_stream",
"run_id": "",
"tags": [],
},
{
"data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
"event": "on_chat_model_start",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {
"input": {"messages": [[HumanMessage(content="hello")]]},
"output": {
"generations": [
[
{
"generation_info": None,
"message": AIMessage(
content="hello world!", id=AnyStr()
),
"text": "hello world!",
"type": "ChatGeneration",
}
]
],
"llm_output": None,
"run": None,
},
},
"event": "on_chat_model_end",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_stream",
"metadata": {},
"name": "i_dont_stream",
"run_id": "",
"tags": [],
},
{
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_end",
"metadata": {},
"name": "i_dont_stream",
"run_id": "",
"tags": [],
},
]
async def test_astream_events_from_model() -> None:
"""Test the output of a model."""
infinite_cycle = cycle([AIMessage(content="hello world!")])