core[patch]: Update unit tests with a workaround for using AnyID in pydantic 2 (#24892)

Pydantic 2 ignores the __eq__ overload on subclasses of str, so equality assertions against messages built with an AnyStr id no longer matched. The tests now build expected values through factory helpers that assign the AnyStr id after model construction.
Eugene Yurtsev 2024-07-31 14:42:12 -04:00 committed by GitHub
parent 8461934c2b
commit 5099a9c9b4
7 changed files with 134 additions and 99 deletions
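Before the diffs, a minimal sketch of the failure mode and of the workaround the changes below adopt. This example is illustrative and not part of the commit: `Message` is a hypothetical stand-in for the langchain-core message classes, and it assumes pydantic 2's default behavior of coercing str subclasses to plain str during validation.

from typing import Any

from pydantic import BaseModel


class AnyStr(str):
    """A wildcard string: compares equal to any other string."""

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, str)


class Message(BaseModel):  # hypothetical stand-in for AIMessage and friends
    content: str
    id: str = ""


# Passing AnyStr() through the constructor runs validation, which coerces
# the subclass back to an exact str and silently drops the __eq__ override:
via_init = Message(content="hello", id=AnyStr())
assert type(via_init.id) is str
assert via_init != Message(content="hello", id="run-123")  # wildcard lost

# Assigning after construction skips validation (validate_assignment is off
# by default), so the AnyStr instance survives and the comparison matches:
via_assign = Message(content="hello")
via_assign.id = AnyStr()
assert type(via_assign.id) is AnyStr
assert via_assign == Message(content="hello", id="run-123")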

View File

@@ -7,9 +7,13 @@ from uuid import UUID
 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.messages.human import HumanMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
-from tests.unit_tests.stubs import AnyStr
+from tests.unit_tests.stubs import (
+    AnyStr,
+    _AnyIdAIMessage,
+    _AnyIdAIMessageChunk,
+    _AnyIdHumanMessage,
+)


 def test_generic_fake_chat_model_invoke() -> None:
@@ -17,11 +21,11 @@ def test_generic_fake_chat_model_invoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello", id=AnyStr())
+    assert response == _AnyIdAIMessage(content="hello")
     response = model.invoke("kitty")
-    assert response == AIMessage(content="goodbye", id=AnyStr())
+    assert response == _AnyIdAIMessage(content="goodbye")
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello", id=AnyStr())
+    assert response == _AnyIdAIMessage(content="hello")


 async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -29,11 +33,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello", id=AnyStr())
+    assert response == _AnyIdAIMessage(content="hello")
     response = await model.ainvoke("kitty")
-    assert response == AIMessage(content="goodbye", id=AnyStr())
+    assert response == _AnyIdAIMessage(content="goodbye")
    response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello", id=AnyStr())
+    assert response == _AnyIdAIMessage(content="hello")


 async def test_generic_fake_chat_model_stream() -> None:
@@ -46,17 +50,17 @@ async def test_generic_fake_chat_model_stream() -> None:
     model = GenericFakeChatModel(messages=infinite_cycle)
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello", id=AnyStr()),
-        AIMessageChunk(content=" ", id=AnyStr()),
-        AIMessageChunk(content="goodbye", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="hello"),
+        _AnyIdAIMessageChunk(content=" "),
+        _AnyIdAIMessageChunk(content="goodbye"),
     ]
     assert len({chunk.id for chunk in chunks}) == 1

     chunks = [chunk for chunk in model.stream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello", id=AnyStr()),
-        AIMessageChunk(content=" ", id=AnyStr()),
-        AIMessageChunk(content="goodbye", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="hello"),
+        _AnyIdAIMessageChunk(content=" "),
+        _AnyIdAIMessageChunk(content="goodbye"),
     ]
     assert len({chunk.id for chunk in chunks}) == 1
@@ -141,9 +145,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
     ]
     final = log_patches[-1]
     assert final.state["streamed_output"] == [
-        AIMessageChunk(content="hello", id=AnyStr()),
-        AIMessageChunk(content=" ", id=AnyStr()),
-        AIMessageChunk(content="goodbye", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="hello"),
+        _AnyIdAIMessageChunk(content=" "),
+        _AnyIdAIMessageChunk(content="goodbye"),
     ]
     assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
@@ -192,9 +196,9 @@ async def test_callback_handlers() -> None:
     # New model
     results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
     assert results == [
-        AIMessageChunk(content="hello", id=AnyStr()),
-        AIMessageChunk(content=" ", id=AnyStr()),
-        AIMessageChunk(content="goodbye", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="hello"),
+        _AnyIdAIMessageChunk(content=" "),
+        _AnyIdAIMessageChunk(content="goodbye"),
     ]
     assert tokens == ["hello", " ", "goodbye"]
     assert len({chunk.id for chunk in results}) == 1
@@ -203,8 +207,6 @@ async def test_callback_handlers() -> None:
 def test_chat_model_inputs() -> None:
     fake = ParrotFakeChatModel()

-    assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr())
-    assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr())
-    assert fake.invoke([AIMessage(content="blah")]) == AIMessage(
-        content="blah", id=AnyStr()
-    )
+    assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello")
+    assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah")
+    assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah")

View File

@@ -24,7 +24,7 @@ from tests.unit_tests.fake.callbacks import (
     FakeAsyncCallbackHandler,
     FakeCallbackHandler,
 )
-from tests.unit_tests.stubs import AnyStr
+from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk


 @pytest.fixture
@@ -144,10 +144,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
     model = ModelWithGenerate()
     chunks = [chunk for chunk in model.stream("anything")]
-    assert chunks == [AIMessage(content="hello", id=AnyStr())]
+    assert chunks == [_AnyIdAIMessage(content="hello")]

     chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [AIMessage(content="hello", id=AnyStr())]
+    assert chunks == [_AnyIdAIMessage(content="hello")]


 async def test_astream_implementation_fallback_to_stream() -> None:
@@ -182,15 +182,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
     model = ModelWithSyncStream()
     chunks = [chunk for chunk in model.stream("anything")]
     assert chunks == [
-        AIMessageChunk(content="a", id=AnyStr()),
-        AIMessageChunk(content="b", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="a"),
+        _AnyIdAIMessageChunk(content="b"),
     ]
     assert len({chunk.id for chunk in chunks}) == 1

     assert type(model)._astream == BaseChatModel._astream
     astream_chunks = [chunk async for chunk in model.astream("anything")]
     assert astream_chunks == [
-        AIMessageChunk(content="a", id=AnyStr()),
-        AIMessageChunk(content="b", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="a"),
+        _AnyIdAIMessageChunk(content="b"),
     ]
     assert len({chunk.id for chunk in astream_chunks}) == 1
@@ -227,8 +227,8 @@ async def test_astream_implementation_uses_astream() -> None:
     model = ModelWithAsyncStream()
     chunks = [chunk async for chunk in model.astream("anything")]
     assert chunks == [
-        AIMessageChunk(content="a", id=AnyStr()),
-        AIMessageChunk(content="b", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="a"),
+        _AnyIdAIMessageChunk(content="b"),
     ]
     assert len({chunk.id for chunk in chunks}) == 1

View File

@@ -37,7 +37,6 @@ from langchain_core.language_models import (
 from langchain_core.load import dumpd, dumps
 from langchain_core.load.load import loads
 from langchain_core.messages import (
-    AIMessage,
     AIMessageChunk,
     HumanMessage,
     SystemMessage,
@@ -90,7 +89,7 @@ from langchain_core.tracers import (
     RunLogPatch,
 )
 from langchain_core.tracers.context import collect_runs
-from tests.unit_tests.stubs import AnyStr
+from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk


 class FakeTracer(BaseTracer):
@@ -1825,7 +1824,7 @@ def test_prompt_with_chat_model(
     tracer = FakeTracer()
     assert chain.invoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
-    ) == AIMessage(content="foo", id=AnyStr())
+    ) == _AnyIdAIMessage(content="foo")
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
         messages=[
@@ -1850,8 +1849,8 @@ def test_prompt_with_chat_model(
         ],
         dict(callbacks=[tracer]),
     ) == [
-        AIMessage(content="foo", id=AnyStr()),
-        AIMessage(content="foo", id=AnyStr()),
+        _AnyIdAIMessage(content="foo"),
+        _AnyIdAIMessage(content="foo"),
     ]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -1891,9 +1890,9 @@ def test_prompt_with_chat_model(
     assert [
         *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
     ] == [
-        AIMessageChunk(content="f", id=AnyStr()),
-        AIMessageChunk(content="o", id=AnyStr()),
-        AIMessageChunk(content="o", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="f"),
+        _AnyIdAIMessageChunk(content="o"),
+        _AnyIdAIMessageChunk(content="o"),
     ]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -1931,7 +1930,7 @@ async def test_prompt_with_chat_model_async(
     tracer = FakeTracer()
     assert await chain.ainvoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
-    ) == AIMessage(content="foo", id=AnyStr())
+    ) == _AnyIdAIMessage(content="foo")
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
         messages=[
@@ -1956,8 +1955,8 @@ async def test_prompt_with_chat_model_async(
         ],
         dict(callbacks=[tracer]),
     ) == [
-        AIMessage(content="foo", id=AnyStr()),
-        AIMessage(content="foo", id=AnyStr()),
+        _AnyIdAIMessage(content="foo"),
+        _AnyIdAIMessage(content="foo"),
     ]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -2000,9 +1999,9 @@ async def test_prompt_with_chat_model_async(
             {"question": "What is your name?"}, dict(callbacks=[tracer])
         )
     ] == [
-        AIMessageChunk(content="f", id=AnyStr()),
-        AIMessageChunk(content="o", id=AnyStr()),
-        AIMessageChunk(content="o", id=AnyStr()),
+        _AnyIdAIMessageChunk(content="f"),
+        _AnyIdAIMessageChunk(content="o"),
+        _AnyIdAIMessageChunk(content="o"),
     ]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -2669,7 +2668,7 @@ def test_prompt_with_chat_model_and_parser(
             HumanMessage(content="What is your name?"),
         ]
     )
-    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
+    assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
     assert tracer.runs == snapshot
@@ -2804,7 +2803,7 @@ What is your name?"""
             ),
         ]
     )
-    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
+    assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
     assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 4
@@ -2850,7 +2849,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
     assert chain.invoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
     ) == {
-        "chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
+        "chat": _AnyIdAIMessage(content="i'm a chatbot"),
         "llm": "i'm a textbot",
     }
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -3060,7 +3059,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
     assert chain.invoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
     ) == {
-        "chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
+        "chat": _AnyIdAIMessage(content="i'm a chatbot"),
         "llm": "i'm a textbot",
         "passthrough": ChatPromptValue(
             messages=[
@@ -3269,7 +3268,7 @@ async def test_map_astream() -> None:
     assert streamed_chunks[0] in [
         {"passthrough": prompt.invoke({"question": "What is your name?"})},
         {"llm": "i"},
-        {"chat": AIMessageChunk(content="i", id=AnyStr())},
+        {"chat": _AnyIdAIMessageChunk(content="i")},
     ]
     assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
     assert all(len(c.keys()) == 1 for c in streamed_chunks)

View File

@@ -30,7 +30,7 @@ from langchain_core.runnables import (
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.runnables.schema import StreamEvent
 from langchain_core.tools import tool
-from tests.unit_tests.stubs import AnyStr
+from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk


 def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
@@ -461,7 +461,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -470,7 +470,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -479,7 +479,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -488,7 +488,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"output": AIMessageChunk(content="hello world!", id=AnyStr())},
+            "data": {"output": _AnyIdAIMessageChunk(content="hello world!")},
             "event": "on_chat_model_end",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -526,7 +526,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -535,7 +535,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -544,7 +544,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -560,9 +560,7 @@ async def test_astream_events_from_model() -> None:
                         [
                             {
                                 "generation_info": None,
-                                "message": AIMessage(
-                                    content="hello world!", id=AnyStr()
-                                ),
+                                "message": _AnyIdAIMessage(content="hello world!"),
                                 "text": "hello world!",
                                 "type": "ChatGeneration",
                             }
@@ -580,7 +578,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_stream",
             "metadata": {},
             "name": "i_dont_stream",
@@ -589,7 +587,7 @@ async def test_astream_events_from_model() -> None:
             "tags": [],
         },
         {
-            "data": {"output": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"output": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_end",
             "metadata": {},
             "name": "i_dont_stream",
@@ -627,7 +625,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -636,7 +634,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -645,7 +643,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
            "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -661,9 +659,7 @@ async def test_astream_events_from_model() -> None:
                         [
                             {
                                 "generation_info": None,
-                                "message": AIMessage(
-                                    content="hello world!", id=AnyStr()
-                                ),
+                                "message": _AnyIdAIMessage(content="hello world!"),
                                 "text": "hello world!",
                                 "type": "ChatGeneration",
                             }
@@ -681,7 +677,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_stream",
             "metadata": {},
             "name": "ai_dont_stream",
@@ -690,7 +686,7 @@ async def test_astream_events_from_model() -> None:
             "tags": [],
         },
         {
-            "data": {"output": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"output": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_end",
             "metadata": {},
             "name": "ai_dont_stream",

View File

@@ -50,7 +50,7 @@ from langchain_core.runnables.schema import StreamEvent
 from langchain_core.runnables.utils import Input, Output
 from langchain_core.tools import tool
 from langchain_core.utils.aiter import aclosing
-from tests.unit_tests.stubs import AnyStr
+from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk


 def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
@@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -540,7 +540,7 @@ async def test_astream_events_from_model() -> None:
         },
         {
             "data": {
-                "output": AIMessageChunk(content="hello world!", id=AnyStr()),
+                "output": _AnyIdAIMessageChunk(content="hello world!"),
             },
             "event": "on_chat_model_end",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
@@ -596,7 +596,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -605,7 +605,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -614,7 +614,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -625,7 +625,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {
                 "input": {"messages": [[HumanMessage(content="hello")]]},
-                "output": AIMessage(content="hello world!", id=AnyStr()),
+                "output": _AnyIdAIMessage(content="hello world!"),
             },
             "event": "on_chat_model_end",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
@@ -635,7 +635,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_stream",
             "metadata": {},
             "name": "i_dont_stream",
@@ -644,7 +644,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": [],
         },
         {
-            "data": {"output": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"output": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_end",
             "metadata": {},
             "name": "i_dont_stream",
@@ -682,7 +682,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -691,7 +691,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -700,7 +700,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
@@ -711,7 +711,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {
                 "input": {"messages": [[HumanMessage(content="hello")]]},
-                "output": AIMessage(content="hello world!", id=AnyStr()),
+                "output": _AnyIdAIMessage(content="hello world!"),
             },
             "event": "on_chat_model_end",
             "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
@@ -721,7 +721,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"chunk": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_stream",
             "metadata": {},
             "name": "ai_dont_stream",
@@ -730,7 +730,7 @@ async def test_astream_with_model_in_chain() -> None:
             "tags": [],
         },
         {
-            "data": {"output": AIMessage(content="hello world!", id=AnyStr())},
+            "data": {"output": _AnyIdAIMessage(content="hello world!")},
             "event": "on_chain_end",
             "metadata": {},
             "name": "ai_dont_stream",

View File

@@ -1,6 +1,44 @@
 from typing import Any

+from langchain_core.documents import Document
+from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
+

 class AnyStr(str):
     def __eq__(self, other: Any) -> bool:
         return isinstance(other, str)
+
+
+# The code below creates versions of the pydantic models that work in unit
+# tests with AnyStr as the id field. Note that the `id` field is assigned
+# AFTER the model is created, to work around an issue with pydantic ignoring
+# the __eq__ method on subclassed strings.
+def _AnyIdDocument(**kwargs: Any) -> Document:
+    """Create a Document with an AnyStr id field."""
+    doc = Document(**kwargs)
+    doc.id = AnyStr()
+    return doc
+
+
+def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
+    """Create an AIMessage with an AnyStr id field."""
+    message = AIMessage(**kwargs)
+    message.id = AnyStr()
+    return message
+
+
+def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
+    """Create an AIMessageChunk with an AnyStr id field."""
+    message = AIMessageChunk(**kwargs)
+    message.id = AnyStr()
+    return message
+
+
+def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
+    """Create a HumanMessage with an AnyStr id field."""
+    message = HumanMessage(**kwargs)
+    message.id = AnyStr()
+    return message
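Design note (a hedged inference from the comment above, not spelled out in the diff): assigning `id` after construction works because these models do not enable validate_assignment, so the assignment bypasses the string validator that would otherwise coerce the AnyStr back to a plain str inside __init__. Pydantic's model equality then compares field values with ==, which dispatches to AnyStr.__eq__ and matches whatever real id the code under test generated.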

View File

@@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
 from langchain_core.documents import Document
 from langchain_core.embeddings.fake import DeterministicFakeEmbedding
 from langchain_core.vectorstores import InMemoryVectorStore
-from tests.unit_tests.stubs import AnyStr
+from tests.unit_tests.stubs import AnyStr, _AnyIdDocument


 class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
@@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
     # Check sync version
     output = store.similarity_search("foo", k=1)
-    assert output == [Document(page_content="foo", id=AnyStr())]
+    assert output == [_AnyIdDocument(page_content="foo")]

     # Check async version
     output = await store.asimilarity_search("bar", k=2)
     assert output == [
-        Document(page_content="bar", id=AnyStr()),
-        Document(page_content="baz", id=AnyStr()),
+        _AnyIdDocument(page_content="bar"),
+        _AnyIdDocument(page_content="baz"),
     ]
@@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
     # make sure we can k > docstore size
     output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
     assert len(output) == len(texts)
-    assert output[0] == Document(page_content="foo", id=AnyStr())
-    assert output[1] == Document(page_content="foy", id=AnyStr())
+    assert output[0] == _AnyIdDocument(page_content="foo")
+    assert output[1] == _AnyIdDocument(page_content="foy")

     # Check async version
     output = await docsearch.amax_marginal_relevance_search(
         "foo", k=10, lambda_mult=0.1
     )
     assert len(output) == len(texts)
-    assert output[0] == Document(page_content="foo", id=AnyStr())
-    assert output[1] == Document(page_content="foy", id=AnyStr())
+    assert output[0] == _AnyIdDocument(page_content="foo")
+    assert output[1] == _AnyIdDocument(page_content="foy")


 async def test_inmemory_dump_load(tmp_path: Path) -> None: