core: autodetect more ls params (#25044)

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Erick Friis 2024-08-08 12:44:21 -07:00 committed by GitHub
parent 86355640c3
commit c6ece6a96d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 3133 additions and 2804 deletions

View File

@ -522,9 +522,37 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
ls_params = LangSmithParams(ls_model_type="chat")
# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.startswith("Chat"):
default_provider = default_provider[4:].lower()
elif default_provider.endswith("Chat"):
default_provider = default_provider[:-4]
default_provider = default_provider.lower()
ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
if stop:
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name
# temperature
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
ls_params["ls_temperature"] = kwargs["temperature"]
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
ls_params["ls_temperature"] = self.temperature
# max_tokens
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
ls_params["ls_max_tokens"] = self.max_tokens
return ls_params
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:

File diff suppressed because one or more lines are too long

View File

@ -57,6 +57,26 @@ async def _collect_events(events: AsyncIterator[StreamEvent]) -> List[StreamEven
return events_
def _assert_events_equal_allow_superset_metadata(events: List, expected: List) -> None:
"""Assert that the events are equal."""
assert len(events) == len(expected)
for i, (event, expected_event) in enumerate(zip(events, expected)):
# we want to allow a superset of metadata on each
event_with_edited_metadata = {
k: (
v
if k != "metadata"
else {
metadata_k: metadata_v
for metadata_k, metadata_v in v.items()
if metadata_k in expected_event["metadata"]
}
)
for k, v in event.items()
}
assert event_with_edited_metadata == expected_event, f"Event {i} did not match."
async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool"""
@ -71,7 +91,9 @@ async def test_event_stream_with_simple_function_tool() -> None:
chain = RunnableLambda(foo) | get_docs
events = await _collect_events(chain.astream_events({}, version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"event": "on_chain_start",
"run_id": "",
@ -144,7 +166,8 @@ async def test_event_stream_with_simple_function_tool() -> None:
"metadata": {},
"data": {"output": [Document(page_content="hello")]},
},
]
],
)
async def test_event_stream_with_single_lambda() -> None:
@ -157,7 +180,9 @@ async def test_event_stream_with_single_lambda() -> None:
chain = RunnableLambda(func=reverse)
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -185,7 +210,8 @@ async def test_event_stream_with_single_lambda() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_triple_lambda() -> None:
@ -201,7 +227,9 @@ async def test_event_stream_with_triple_lambda() -> None:
| r.with_config({"run_name": "3"})
)
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -310,7 +338,8 @@ async def test_event_stream_with_triple_lambda() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_triple_lambda_test_filtering() -> None:
@ -330,7 +359,9 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
events = await _collect_events(
chain.astream_events("hello", include_names=["1"], version="v1")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {},
"event": "on_chain_start",
@ -358,14 +389,17 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"parent_ids": [],
"tags": ["seq:step:1"],
},
]
],
)
events = await _collect_events(
chain.astream_events(
"hello", include_tags=["my_tag"], exclude_names=["2"], version="v1"
)
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {},
"event": "on_chain_start",
@ -393,7 +427,8 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"parent_ids": [],
"tags": ["my_tag", "seq:step:3"],
},
]
],
)
async def test_event_stream_with_lambdas_from_lambda() -> None:
@ -403,7 +438,9 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
events = await _collect_events(
as_lambdas.astream_events({"question": "hello"}, version="v1")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
@ -431,7 +468,8 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_astream_events_from_model() -> None:
@ -450,7 +488,9 @@ async def test_astream_events_from_model() -> None:
.bind(stop="<stop_token>")
)
events = await _collect_events(model.astream_events("hello", version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chat_model_start",
@ -496,7 +536,8 @@ async def test_astream_events_from_model() -> None:
"parent_ids": [],
"tags": ["my_model"],
},
]
],
)
@RunnableLambda
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
@ -506,7 +547,9 @@ async def test_astream_events_from_model() -> None:
return model.invoke(input, config)
events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -519,7 +562,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
"event": "on_chat_model_start",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -528,7 +575,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -537,7 +588,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -546,7 +601,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -571,7 +630,11 @@ async def test_astream_events_from_model() -> None:
},
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -595,7 +658,8 @@ async def test_astream_events_from_model() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
@RunnableLambda
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
@ -605,7 +669,9 @@ async def test_astream_events_from_model() -> None:
return await model.ainvoke(input, config)
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -618,7 +684,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
"event": "on_chat_model_start",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -627,7 +697,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -636,7 +710,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -645,7 +723,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -670,7 +752,11 @@ async def test_astream_events_from_model() -> None:
},
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -694,7 +780,8 @@ async def test_astream_events_from_model() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_simple_chain() -> None:
@ -733,7 +820,9 @@ async def test_event_stream_with_simple_chain() -> None:
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v1")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
@ -909,7 +998,8 @@ async def test_event_stream_with_simple_chain() -> None:
"parent_ids": [],
"tags": ["my_chain"],
},
]
],
)
async def test_event_streaming_with_tools() -> None:
@ -938,7 +1028,9 @@ async def test_event_streaming_with_tools() -> None:
# type ignores below because the tools don't appear to be runnables to type checkers
# we can remove as soon as that's fixed
events = await _collect_events(parameterless.astream_events({}, version="v1")) # type: ignore
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {}},
"event": "on_tool_start",
@ -966,10 +1058,13 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
events = await _collect_events(with_callbacks.astream_events({}, version="v1")) # type: ignore
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {}},
"event": "on_tool_start",
@ -997,11 +1092,14 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
events = await _collect_events(
with_parameters.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"x": 1, "y": "2"}},
"event": "on_tool_start",
@ -1029,12 +1127,15 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
events = await _collect_events(
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"x": 1, "y": "2"}},
"event": "on_tool_start",
@ -1062,7 +1163,8 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
class HardCodedRetriever(BaseRetriever):
@ -1091,7 +1193,9 @@ async def test_event_stream_with_retriever() -> None:
events = await _collect_events(
retriever.astream_events({"query": "hello"}, version="v1")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {
"input": {"query": "hello"},
@ -1107,7 +1211,9 @@ async def test_event_stream_with_retriever() -> None:
"data": {
"chunk": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
),
]
},
"event": "on_retriever_stream",
@ -1121,7 +1227,9 @@ async def test_event_stream_with_retriever() -> None:
"data": {
"output": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
),
],
},
"event": "on_retriever_end",
@ -1131,7 +1239,8 @@ async def test_event_stream_with_retriever() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_retriever_and_formatter() -> None:
@ -1155,7 +1264,9 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
chain = retriever | format_docs
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -1179,9 +1290,12 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
"input": {"query": "hello"},
"output": {
"documents": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
page_content="hello world!", metadata={"foo": "bar"}
),
Document(
page_content="goodbye world!",
metadata={"food": "spare"},
),
]
},
@ -1224,7 +1338,9 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
"data": {
"input": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
),
],
"output": "hello world!, goodbye world!",
},
@ -1244,7 +1360,8 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_on_chain_with_tool() -> None:
@ -1266,7 +1383,9 @@ async def test_event_stream_on_chain_with_tool() -> None:
events = await _collect_events(
chain.astream_events({"a": "hello", "b": "world"}, version="v1")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"a": "hello", "b": "world"}},
"event": "on_chain_start",
@ -1339,7 +1458,8 @@ async def test_event_stream_on_chain_with_tool() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
@pytest.mark.xfail(reason="Fix order of callback invocations in RunnableSequence")
@ -1368,7 +1488,9 @@ async def test_chain_ordering() -> None:
for event in events:
event["tags"] = sorted(event["tags"])
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "q"},
"event": "on_chain_start",
@ -1450,7 +1572,8 @@ async def test_chain_ordering() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_retry() -> None:
@ -1481,7 +1604,9 @@ async def test_event_stream_with_retry() -> None:
for event in events:
event["tags"] = sorted(event["tags"])
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "q"},
"event": "on_chain_start",
@ -1536,7 +1661,8 @@ async def test_event_stream_with_retry() -> None:
"parent_ids": [],
"tags": ["seq:step:2"],
},
]
],
)
async def test_with_llm() -> None:
@ -1550,7 +1676,9 @@ async def test_with_llm() -> None:
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v1")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
@ -1588,7 +1716,9 @@ async def test_with_llm() -> None:
},
{
"data": {
"input": {"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]}
"input": {
"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]
}
},
"event": "on_llm_start",
"metadata": {},
@ -1604,7 +1734,13 @@ async def test_with_llm() -> None:
},
"output": {
"generations": [
[{"generation_info": None, "text": "abc", "type": "Generation"}]
[
{
"generation_info": None,
"text": "abc",
"type": "Generation",
}
]
],
"llm_output": None,
"run": None,
@ -1653,7 +1789,8 @@ async def test_with_llm() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_runnable_each() -> None:
@ -1686,7 +1823,9 @@ async def test_events_astream_config() -> None:
assert model_02.invoke("hello") == AIMessage(content="Goodbye world", id="ai2")
events = await _collect_events(model_02.astream_events("hello", version="v1"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chat_model_start",
@ -1732,7 +1871,8 @@ async def test_events_astream_config() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_runnable_with_message_history() -> None:
@ -1886,7 +2026,7 @@ async def test_sync_in_async_stream_lambdas() -> None:
add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore
events = await _collect_events(add_one_proxy.astream_events(1, version="v1"))
assert events == EXPECTED_EVENTS
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
async def test_async_in_async_stream_lambdas() -> None:
@ -1906,7 +2046,7 @@ async def test_async_in_async_stream_lambdas() -> None:
add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore
events = await _collect_events(add_one_proxy_.astream_events(1, version="v1"))
assert events == EXPECTED_EVENTS
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
@pytest.mark.xfail(
@ -1931,4 +2071,4 @@ async def test_sync_in_sync_lambdas() -> None:
add_one_proxy_ = RunnableLambda(add_one_proxy)
events = await _collect_events(add_one_proxy_.astream_events(1, version="v1"))
assert events == EXPECTED_EVENTS
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)

View File

@ -50,6 +50,9 @@ from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool
from langchain_core.utils.aiter import aclosing
from tests.unit_tests.runnables.test_runnable_events_v1 import (
_assert_events_equal_allow_superset_metadata,
)
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
@ -106,7 +109,9 @@ async def test_event_stream_with_simple_function_tool() -> None:
chain = RunnableLambda(foo) | get_docs
events = await _collect_events(chain.astream_events({}, version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"event": "on_chain_start",
"run_id": "",
@ -179,7 +184,8 @@ async def test_event_stream_with_simple_function_tool() -> None:
"metadata": {},
"data": {"output": [Document(page_content="hello")]},
},
]
],
)
async def test_event_stream_with_single_lambda() -> None:
@ -192,7 +198,9 @@ async def test_event_stream_with_single_lambda() -> None:
chain = RunnableLambda(func=reverse)
events = await _collect_events(chain.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -220,7 +228,8 @@ async def test_event_stream_with_single_lambda() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_triple_lambda() -> None:
@ -236,7 +245,9 @@ async def test_event_stream_with_triple_lambda() -> None:
| r.with_config({"run_name": "3"})
)
events = await _collect_events(chain.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -345,7 +356,8 @@ async def test_event_stream_with_triple_lambda() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_exception() -> None:
@ -381,7 +393,9 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
events = await _collect_events(
chain.astream_events("hello", include_names=["1"], version="v2")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -409,14 +423,17 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"parent_ids": [],
"tags": ["seq:step:1"],
},
]
],
)
events = await _collect_events(
chain.astream_events(
"hello", include_tags=["my_tag"], exclude_names=["2"], version="v2"
)
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -444,7 +461,8 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"parent_ids": [],
"tags": ["my_tag", "seq:step:3"],
},
]
],
)
async def test_event_stream_with_lambdas_from_lambda() -> None:
@ -454,7 +472,9 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
events = await _collect_events(
as_lambdas.astream_events({"question": "hello"}, version="v2")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
@ -482,7 +502,8 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_astream_events_from_model() -> None:
@ -501,11 +522,17 @@ async def test_astream_events_from_model() -> None:
.bind(stop="<stop_token>")
)
events = await _collect_events(model.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chat_model_start",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -514,7 +541,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -523,7 +554,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -532,7 +567,11 @@ async def test_astream_events_from_model() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -543,13 +582,18 @@ async def test_astream_events_from_model() -> None:
"output": _AnyIdAIMessageChunk(content="hello world!"),
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
"tags": ["my_model"],
},
]
],
)
async def test_astream_with_model_in_chain() -> None:
@ -576,7 +620,9 @@ async def test_astream_with_model_in_chain() -> None:
return model.invoke(input, config)
events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -589,7 +635,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
"event": "on_chat_model_start",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -598,7 +648,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -607,7 +661,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -616,7 +674,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -628,7 +690,11 @@ async def test_astream_with_model_in_chain() -> None:
"output": _AnyIdAIMessage(content="hello world!"),
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -652,7 +718,8 @@ async def test_astream_with_model_in_chain() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
@RunnableLambda
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
@ -662,7 +729,9 @@ async def test_astream_with_model_in_chain() -> None:
return await model.ainvoke(input, config)
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -675,7 +744,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
"event": "on_chat_model_start",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -684,7 +757,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -693,7 +770,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -702,7 +783,11 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -714,7 +799,11 @@ async def test_astream_with_model_in_chain() -> None:
"output": _AnyIdAIMessage(content="hello world!"),
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"metadata": {
"a": "b",
"ls_model_type": "chat",
"ls_stop": "<stop_token>",
},
"name": "my_model",
"run_id": "",
"parent_ids": [],
@ -738,7 +827,8 @@ async def test_astream_with_model_in_chain() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_simple_chain() -> None:
@ -777,7 +867,9 @@ async def test_event_stream_with_simple_chain() -> None:
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v2")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
@ -938,7 +1030,8 @@ async def test_event_stream_with_simple_chain() -> None:
"parent_ids": [],
"tags": ["my_chain"],
},
]
],
)
async def test_event_streaming_with_tools() -> None:
@ -967,7 +1060,9 @@ async def test_event_streaming_with_tools() -> None:
# type ignores below because the tools don't appear to be runnables to type checkers
# we can remove as soon as that's fixed
events = await _collect_events(parameterless.astream_events({}, version="v2")) # type: ignore
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {}},
"event": "on_tool_start",
@ -986,9 +1081,12 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
events = await _collect_events(with_callbacks.astream_events({}, version="v2")) # type: ignore
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {}},
"event": "on_tool_start",
@ -1007,11 +1105,14 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
events = await _collect_events(
with_parameters.astream_events({"x": 1, "y": "2"}, version="v2") # type: ignore
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"x": 1, "y": "2"}},
"event": "on_tool_start",
@ -1030,12 +1131,15 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
events = await _collect_events(
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}, version="v2") # type: ignore
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"x": 1, "y": "2"}},
"event": "on_tool_start",
@ -1054,7 +1158,8 @@ async def test_event_streaming_with_tools() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
class HardCodedRetriever(BaseRetriever):
@ -1083,7 +1188,9 @@ async def test_event_stream_with_retriever() -> None:
events = await _collect_events(
retriever.astream_events({"query": "hello"}, version="v2")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {
"input": {"query": "hello"},
@ -1099,7 +1206,9 @@ async def test_event_stream_with_retriever() -> None:
"data": {
"output": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
),
]
},
"event": "on_retriever_end",
@ -1109,7 +1218,8 @@ async def test_event_stream_with_retriever() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_retriever_and_formatter() -> None:
@ -1133,7 +1243,9 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
chain = retriever | format_docs
events = await _collect_events(chain.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -1157,7 +1269,9 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
"input": {"query": "hello"},
"output": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
),
],
},
"event": "on_retriever_end",
@ -1198,7 +1312,9 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
"data": {
"input": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
),
],
"output": "hello world!, goodbye world!",
},
@ -1218,7 +1334,8 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_on_chain_with_tool() -> None:
@ -1240,7 +1357,9 @@ async def test_event_stream_on_chain_with_tool() -> None:
events = await _collect_events(
chain.astream_events({"a": "hello", "b": "world"}, version="v2")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"a": "hello", "b": "world"}},
"event": "on_chain_start",
@ -1313,7 +1432,8 @@ async def test_event_stream_on_chain_with_tool() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
@pytest.mark.xfail(reason="Fix order of callback invocations in RunnableSequence")
@ -1342,7 +1462,9 @@ async def test_chain_ordering() -> None:
for event in events:
event["tags"] = sorted(event["tags"])
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "q"},
"event": "on_chain_start",
@ -1424,7 +1546,8 @@ async def test_chain_ordering() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_event_stream_with_retry() -> None:
@ -1455,7 +1578,9 @@ async def test_event_stream_with_retry() -> None:
for event in events:
event["tags"] = sorted(event["tags"])
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "q"},
"event": "on_chain_start",
@ -1501,7 +1626,8 @@ async def test_event_stream_with_retry() -> None:
"parent_ids": [],
"tags": ["seq:step:1"],
},
]
],
)
async def test_with_llm() -> None:
@ -1515,7 +1641,9 @@ async def test_with_llm() -> None:
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v2")
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
@ -1553,7 +1681,9 @@ async def test_with_llm() -> None:
},
{
"data": {
"input": {"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]}
"input": {
"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]
}
},
"event": "on_llm_start",
"metadata": {},
@ -1569,7 +1699,13 @@ async def test_with_llm() -> None:
},
"output": {
"generations": [
[{"generation_info": None, "text": "abc", "type": "Generation"}]
[
{
"generation_info": None,
"text": "abc",
"type": "Generation",
}
]
],
"llm_output": None,
},
@ -1617,7 +1753,8 @@ async def test_with_llm() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_runnable_each() -> None:
@ -1650,7 +1787,9 @@ async def test_events_astream_config() -> None:
assert model_02.invoke("hello") == AIMessage(content="Goodbye world", id="ai2")
events = await _collect_events(model_02.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chat_model_start",
@ -1698,7 +1837,8 @@ async def test_events_astream_config() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_runnable_with_message_history() -> None:
@ -1847,7 +1987,7 @@ async def test_sync_in_async_stream_lambdas() -> None:
add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore
events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
assert events == EXPECTED_EVENTS
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
async def test_async_in_async_stream_lambdas() -> None:
@ -1867,7 +2007,7 @@ async def test_async_in_async_stream_lambdas() -> None:
add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore
events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
assert events == EXPECTED_EVENTS
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
async def test_sync_in_sync_lambdas() -> None:
@ -1887,7 +2027,7 @@ async def test_sync_in_sync_lambdas() -> None:
add_one_proxy_ = RunnableLambda(add_one_proxy)
events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
assert events == EXPECTED_EVENTS
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
class StreamingRunnable(Runnable[Input, Output]):
@ -1955,7 +2095,9 @@ async def test_astream_events_from_custom_runnable() -> None:
chunks = [chunk async for chunk in runnable.astream(1, version="v2")]
assert chunks == ["1", "2", "3"]
events = await _collect_events(runnable.astream_events(1, version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": 1},
"event": "on_chain_start",
@ -2001,7 +2143,8 @@ async def test_astream_events_from_custom_runnable() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_parent_run_id_assignment() -> None:
@ -2029,7 +2172,9 @@ async def test_parent_run_id_assignment() -> None:
parent.astream_events("hello", {"run_id": bond}, version="v2"),
with_nulled_ids=False,
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -2099,7 +2244,8 @@ async def test_parent_run_id_assignment() -> None:
"run_id": "00000000-0000-0000-0000-000000000007",
"tags": [],
},
]
],
)
async def test_bad_parent_ids() -> None:
@ -2125,7 +2271,9 @@ async def test_bad_parent_ids() -> None:
# Includes only a partial list of events since the run ID gets duplicated
# between parent and child run ID and the callback handler throws an exception.
# The exception does not get bubbled up to the user.
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -2135,7 +2283,8 @@ async def test_bad_parent_ids() -> None:
"run_id": "00000000-0000-0000-0000-000000000007",
"tags": [],
}
]
],
)
async def test_runnable_generator() -> None:
@ -2147,7 +2296,9 @@ async def test_runnable_generator() -> None:
runnable: Runnable[str, str] = RunnableGenerator(transform=generator)
events = await _collect_events(runnable.astream_events("hello", version="v2"))
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": "hello"},
"event": "on_chain_start",
@ -2184,7 +2335,8 @@ async def test_runnable_generator() -> None:
"parent_ids": [],
"tags": [],
},
]
],
)
async def test_with_explicit_config() -> None:
@ -2380,7 +2532,9 @@ async def test_custom_event() -> None:
)
run_id = str(uuid1)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": 1},
"event": "on_chain_start",
@ -2426,7 +2580,8 @@ async def test_custom_event() -> None:
"run_id": run_id,
"tags": [],
},
]
],
)
async def test_custom_event_nested() -> None:
@ -2467,7 +2622,9 @@ async def test_custom_event_nested() -> None:
run_id = str(run_id) # type: ignore[assignment]
child_run_id = str(child_run_id) # type: ignore[assignment]
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": 1},
"event": "on_chain_start",
@ -2531,7 +2688,8 @@ async def test_custom_event_nested() -> None:
"run_id": "00000000-0000-0000-0000-000000000007",
"tags": [],
},
]
],
)
async def test_custom_event_root_dispatch() -> None:
@ -2566,7 +2724,9 @@ async def test_custom_event_root_dispatch_with_in_tool() -> None:
events = await _collect_events(
foo.astream_events({"x": 2}, version="v2") # type: ignore[attr-defined]
)
assert events == [
_assert_events_equal_allow_superset_metadata(
events,
[
{
"data": {"input": {"x": 2}},
"event": "on_tool_start",
@ -2594,4 +2754,5 @@ async def test_custom_event_root_dispatch_with_in_tool() -> None:
"run_id": "",
"tags": [],
},
]
],
)