diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 5fcefcd9571..feff8a4a7ba 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -7,9 +7,13 @@ from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage -from langchain_core.messages.human import HumanMessage from langchain_core.outputs import ChatGenerationChunk, GenerationChunk -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import ( + AnyStr, + _AnyIdAIMessage, + _AnyIdAIMessageChunk, + _AnyIdHumanMessage, +) def test_generic_fake_chat_model_invoke() -> None: @@ -17,11 +21,11 @@ def test_generic_fake_chat_model_invoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = model.invoke("meow") - assert response == AIMessage(content="hello", id=AnyStr()) + assert response == _AnyIdAIMessage(content="hello") response = model.invoke("kitty") - assert response == AIMessage(content="goodbye", id=AnyStr()) + assert response == _AnyIdAIMessage(content="goodbye") response = model.invoke("meow") - assert response == AIMessage(content="hello", id=AnyStr()) + assert response == _AnyIdAIMessage(content="hello") async def test_generic_fake_chat_model_ainvoke() -> None: @@ -29,11 +33,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = await model.ainvoke("meow") - assert response == AIMessage(content="hello", id=AnyStr()) + assert response == _AnyIdAIMessage(content="hello") response = await model.ainvoke("kitty") - assert response == AIMessage(content="goodbye", id=AnyStr()) + assert response == _AnyIdAIMessage(content="goodbye") response = await model.ainvoke("meow") - assert response == AIMessage(content="hello", id=AnyStr()) + assert response == _AnyIdAIMessage(content="hello") async def test_generic_fake_chat_model_stream() -> None: @@ -46,17 +50,17 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=infinite_cycle) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - AIMessageChunk(content="hello", id=AnyStr()), - AIMessageChunk(content=" ", id=AnyStr()), - AIMessageChunk(content="goodbye", id=AnyStr()), + _AnyIdAIMessageChunk(content="hello"), + _AnyIdAIMessageChunk(content=" "), + _AnyIdAIMessageChunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 chunks = [chunk for chunk in model.stream("meow")] assert chunks == [ - AIMessageChunk(content="hello", id=AnyStr()), - AIMessageChunk(content=" ", id=AnyStr()), - AIMessageChunk(content="goodbye", id=AnyStr()), + _AnyIdAIMessageChunk(content="hello"), + _AnyIdAIMessageChunk(content=" "), + _AnyIdAIMessageChunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 @@ -141,9 +145,9 @@ async def test_generic_fake_chat_model_astream_log() -> None: ] final = log_patches[-1] assert final.state["streamed_output"] == [ - AIMessageChunk(content="hello", id=AnyStr()), - AIMessageChunk(content=" ", id=AnyStr()), - AIMessageChunk(content="goodbye", id=AnyStr()), + _AnyIdAIMessageChunk(content="hello"), + _AnyIdAIMessageChunk(content=" "), + _AnyIdAIMessageChunk(content="goodbye"), ] assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1 @@ -192,9 +196,9 @@ async def test_callback_handlers() -> None: # New model results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) assert results == [ - AIMessageChunk(content="hello", id=AnyStr()), - AIMessageChunk(content=" ", id=AnyStr()), - AIMessageChunk(content="goodbye", id=AnyStr()), + _AnyIdAIMessageChunk(content="hello"), + _AnyIdAIMessageChunk(content=" "), + _AnyIdAIMessageChunk(content="goodbye"), ] assert tokens == ["hello", " ", "goodbye"] assert len({chunk.id for chunk in results}) == 1 @@ -203,8 +207,6 @@ async def test_callback_handlers() -> None: def test_chat_model_inputs() -> None: fake = ParrotFakeChatModel() - assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr()) - assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr()) - assert fake.invoke([AIMessage(content="blah")]) == AIMessage( - content="blah", id=AnyStr() - ) + assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello") + assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah") + assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah") diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index a0d79ca9ea4..4c08a5441d1 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -24,7 +24,7 @@ from tests.unit_tests.fake.callbacks import ( FakeAsyncCallbackHandler, FakeCallbackHandler, ) -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk @pytest.fixture @@ -144,10 +144,10 @@ async def test_astream_fallback_to_ainvoke() -> None: model = ModelWithGenerate() chunks = [chunk for chunk in model.stream("anything")] - assert chunks == [AIMessage(content="hello", id=AnyStr())] + assert chunks == [_AnyIdAIMessage(content="hello")] chunks = [chunk async for chunk in model.astream("anything")] - assert chunks == [AIMessage(content="hello", id=AnyStr())] + assert chunks == [_AnyIdAIMessage(content="hello")] async def test_astream_implementation_fallback_to_stream() -> None: @@ -182,15 +182,15 @@ async def test_astream_implementation_fallback_to_stream() -> None: model = ModelWithSyncStream() chunks = [chunk for chunk in model.stream("anything")] assert chunks == [ - AIMessageChunk(content="a", id=AnyStr()), - AIMessageChunk(content="b", id=AnyStr()), + _AnyIdAIMessageChunk(content="a"), + _AnyIdAIMessageChunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 assert type(model)._astream == BaseChatModel._astream astream_chunks = [chunk async for chunk in model.astream("anything")] assert astream_chunks == [ - AIMessageChunk(content="a", id=AnyStr()), - AIMessageChunk(content="b", id=AnyStr()), + _AnyIdAIMessageChunk(content="a"), + _AnyIdAIMessageChunk(content="b"), ] assert len({chunk.id for chunk in astream_chunks}) == 1 @@ -227,8 +227,8 @@ async def test_astream_implementation_uses_astream() -> None: model = ModelWithAsyncStream() chunks = [chunk async for chunk in model.astream("anything")] assert chunks == [ - AIMessageChunk(content="a", id=AnyStr()), - AIMessageChunk(content="b", id=AnyStr()), + _AnyIdAIMessageChunk(content="a"), + _AnyIdAIMessageChunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index bfbd3d19d35..e490384d6e5 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -37,7 +37,6 @@ from langchain_core.language_models import ( from langchain_core.load import dumpd, dumps from langchain_core.load.load import loads from langchain_core.messages import ( - AIMessage, AIMessageChunk, HumanMessage, SystemMessage, @@ -90,7 +89,7 @@ from langchain_core.tracers import ( RunLogPatch, ) from langchain_core.tracers.context import collect_runs -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk class FakeTracer(BaseTracer): @@ -1825,7 +1824,7 @@ def test_prompt_with_chat_model( tracer = FakeTracer() assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == AIMessage(content="foo", id=AnyStr()) + ) == _AnyIdAIMessage(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1850,8 +1849,8 @@ def test_prompt_with_chat_model( ], dict(callbacks=[tracer]), ) == [ - AIMessage(content="foo", id=AnyStr()), - AIMessage(content="foo", id=AnyStr()), + _AnyIdAIMessage(content="foo"), + _AnyIdAIMessage(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -1891,9 +1890,9 @@ def test_prompt_with_chat_model( assert [ *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) ] == [ - AIMessageChunk(content="f", id=AnyStr()), - AIMessageChunk(content="o", id=AnyStr()), - AIMessageChunk(content="o", id=AnyStr()), + _AnyIdAIMessageChunk(content="f"), + _AnyIdAIMessageChunk(content="o"), + _AnyIdAIMessageChunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -1931,7 +1930,7 @@ async def test_prompt_with_chat_model_async( tracer = FakeTracer() assert await chain.ainvoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == AIMessage(content="foo", id=AnyStr()) + ) == _AnyIdAIMessage(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1956,8 +1955,8 @@ async def test_prompt_with_chat_model_async( ], dict(callbacks=[tracer]), ) == [ - AIMessage(content="foo", id=AnyStr()), - AIMessage(content="foo", id=AnyStr()), + _AnyIdAIMessage(content="foo"), + _AnyIdAIMessage(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -2000,9 +1999,9 @@ async def test_prompt_with_chat_model_async( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) ] == [ - AIMessageChunk(content="f", id=AnyStr()), - AIMessageChunk(content="o", id=AnyStr()), - AIMessageChunk(content="o", id=AnyStr()), + _AnyIdAIMessageChunk(content="f"), + _AnyIdAIMessageChunk(content="o"), + _AnyIdAIMessageChunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -2669,7 +2668,7 @@ def test_prompt_with_chat_model_and_parser( HumanMessage(content="What is your name?"), ] ) - assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr()) + assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") assert tracer.runs == snapshot @@ -2804,7 +2803,7 @@ What is your name?""" ), ] ) - assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr()) + assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 4 @@ -2850,7 +2849,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": AIMessage(content="i'm a chatbot", id=AnyStr()), + "chat": _AnyIdAIMessage(content="i'm a chatbot"), "llm": "i'm a textbot", } assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -3060,7 +3059,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": AIMessage(content="i'm a chatbot", id=AnyStr()), + "chat": _AnyIdAIMessage(content="i'm a chatbot"), "llm": "i'm a textbot", "passthrough": ChatPromptValue( messages=[ @@ -3269,7 +3268,7 @@ async def test_map_astream() -> None: assert streamed_chunks[0] in [ {"passthrough": prompt.invoke({"question": "What is your name?"})}, {"llm": "i"}, - {"chat": AIMessageChunk(content="i", id=AnyStr())}, + {"chat": _AnyIdAIMessageChunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert all(len(c.keys()) == 1 for c in streamed_chunks) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 47e417d0410..c73ca1e6482 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -30,7 +30,7 @@ from langchain_core.runnables import ( from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import tool -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: @@ -461,7 +461,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -470,7 +470,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -479,7 +479,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -488,7 +488,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"output": AIMessageChunk(content="hello world!", id=AnyStr())}, + "data": {"output": _AnyIdAIMessageChunk(content="hello world!")}, "event": "on_chat_model_end", "metadata": {"a": "b"}, "name": "my_model", @@ -526,7 +526,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -535,7 +535,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -544,7 +544,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -560,9 +560,7 @@ async def test_astream_events_from_model() -> None: [ { "generation_info": None, - "message": AIMessage( - content="hello world!", id=AnyStr() - ), + "message": _AnyIdAIMessage(content="hello world!"), "text": "hello world!", "type": "ChatGeneration", } @@ -580,7 +578,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -589,7 +587,7 @@ async def test_astream_events_from_model() -> None: "tags": [], }, { - "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"output": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -627,7 +625,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -636,7 +634,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -645,7 +643,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -661,9 +659,7 @@ async def test_astream_events_from_model() -> None: [ { "generation_info": None, - "message": AIMessage( - content="hello world!", id=AnyStr() - ), + "message": _AnyIdAIMessage(content="hello world!"), "text": "hello world!", "type": "ChatGeneration", } @@ -681,7 +677,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -690,7 +686,7 @@ async def test_astream_events_from_model() -> None: "tags": [], }, { - "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"output": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 69bdfa5ed01..f4f83a95b0a 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -50,7 +50,7 @@ from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.utils import Input, Output from langchain_core.tools import tool from langchain_core.utils.aiter import aclosing -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: @@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -540,7 +540,7 @@ async def test_astream_events_from_model() -> None: }, { "data": { - "output": AIMessageChunk(content="hello world!", id=AnyStr()), + "output": _AnyIdAIMessageChunk(content="hello world!"), }, "event": "on_chat_model_end", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, @@ -596,7 +596,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -605,7 +605,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -614,7 +614,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -625,7 +625,7 @@ async def test_astream_with_model_in_chain() -> None: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": AIMessage(content="hello world!", id=AnyStr()), + "output": _AnyIdAIMessage(content="hello world!"), }, "event": "on_chat_model_end", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, @@ -635,7 +635,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -644,7 +644,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": [], }, { - "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"output": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -682,7 +682,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -691,7 +691,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -700,7 +700,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, "name": "my_model", @@ -711,7 +711,7 @@ async def test_astream_with_model_in_chain() -> None: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": AIMessage(content="hello world!", id=AnyStr()), + "output": _AnyIdAIMessage(content="hello world!"), }, "event": "on_chat_model_end", "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": ""}, @@ -721,7 +721,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": ["my_model"], }, { - "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -730,7 +730,7 @@ async def test_astream_with_model_in_chain() -> None: "tags": [], }, { - "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + "data": {"output": _AnyIdAIMessage(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/stubs.py b/libs/core/tests/unit_tests/stubs.py index 38e84a3a7aa..b752364e3af 100644 --- a/libs/core/tests/unit_tests/stubs.py +++ b/libs/core/tests/unit_tests/stubs.py @@ -1,6 +1,44 @@ from typing import Any +from langchain_core.documents import Document +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage + class AnyStr(str): def __eq__(self, other: Any) -> bool: return isinstance(other, str) + + +# The code below creates version of pydantic models +# that will work in unit tests with AnyStr as id field +# Please note that the `id` field is assigned AFTER the model is created +# to workaround an issue with pydantic ignoring the __eq__ method on +# subclassed strings. + + +def _AnyIdDocument(**kwargs: Any) -> Document: + """Create a document with an id field.""" + message = Document(**kwargs) + message.id = AnyStr() + return message + + +def _AnyIdAIMessage(**kwargs: Any) -> AIMessage: + """Create ai message with an any id field.""" + message = AIMessage(**kwargs) + message.id = AnyStr() + return message + + +def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk: + """Create ai message with an any id field.""" + message = AIMessageChunk(**kwargs) + message.id = AnyStr() + return message + + +def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: + """Create a human with an any id field.""" + message = HumanMessage(**kwargs) + message.id = AnyStr() + return message diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py index 8a3a3b6407d..8d30afac3ca 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py +++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py @@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import ( from langchain_core.documents import Document from langchain_core.embeddings.fake import DeterministicFakeEmbedding from langchain_core.vectorstores import InMemoryVectorStore -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import AnyStr, _AnyIdDocument class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite): @@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None: # Check sync version output = store.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", id=AnyStr())] + assert output == [_AnyIdDocument(page_content="foo")] # Check async version output = await store.asimilarity_search("bar", k=2) assert output == [ - Document(page_content="bar", id=AnyStr()), - Document(page_content="baz", id=AnyStr()), + _AnyIdDocument(page_content="bar"), + _AnyIdDocument(page_content="baz"), ] @@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None: # make sure we can k > docstore size output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1) assert len(output) == len(texts) - assert output[0] == Document(page_content="foo", id=AnyStr()) - assert output[1] == Document(page_content="foy", id=AnyStr()) + assert output[0] == _AnyIdDocument(page_content="foo") + assert output[1] == _AnyIdDocument(page_content="foy") # Check async version output = await docsearch.amax_marginal_relevance_search( "foo", k=10, lambda_mult=0.1 ) assert len(output) == len(texts) - assert output[0] == Document(page_content="foo", id=AnyStr()) - assert output[1] == Document(page_content="foy", id=AnyStr()) + assert output[0] == _AnyIdDocument(page_content="foo") + assert output[1] == _AnyIdDocument(page_content="foy") async def test_inmemory_dump_load(tmp_path: Path) -> None: