mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
core[patch]: Update unit tests with a workaround for using AnyID in pydantic 2 (#24892)
Pydantic 2 ignores __eq__ overload for subclasses of strings.
This commit is contained in:
parent
8461934c2b
commit
5099a9c9b4
@ -7,9 +7,13 @@ from uuid import UUID
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||
from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
from tests.unit_tests.stubs import (
|
||||
AnyStr,
|
||||
_AnyIdAIMessage,
|
||||
_AnyIdAIMessageChunk,
|
||||
_AnyIdHumanMessage,
|
||||
)
|
||||
|
||||
|
||||
def test_generic_fake_chat_model_invoke() -> None:
|
||||
@ -17,11 +21,11 @@ def test_generic_fake_chat_model_invoke() -> None:
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
response = model.invoke("kitty")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="goodbye")
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
@ -29,11 +33,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
response = await model.ainvoke("kitty")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="goodbye")
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
assert response == _AnyIdAIMessage(content="hello")
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_stream() -> None:
|
||||
@ -46,17 +50,17 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
chunks = [chunk for chunk in model.stream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@ -141,9 +145,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
|
||||
]
|
||||
final = log_patches[-1]
|
||||
assert final.state["streamed_output"] == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
|
||||
|
||||
@ -192,9 +196,9 @@ async def test_callback_handlers() -> None:
|
||||
# New model
|
||||
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
||||
assert results == [
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="hello"),
|
||||
_AnyIdAIMessageChunk(content=" "),
|
||||
_AnyIdAIMessageChunk(content="goodbye"),
|
||||
]
|
||||
assert tokens == ["hello", " ", "goodbye"]
|
||||
assert len({chunk.id for chunk in results}) == 1
|
||||
@ -203,8 +207,6 @@ async def test_callback_handlers() -> None:
|
||||
def test_chat_model_inputs() -> None:
|
||||
fake = ParrotFakeChatModel()
|
||||
|
||||
assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr())
|
||||
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr())
|
||||
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(
|
||||
content="blah", id=AnyStr()
|
||||
)
|
||||
assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello")
|
||||
assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah")
|
||||
assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah")
|
||||
|
@ -24,7 +24,7 @@ from tests.unit_tests.fake.callbacks import (
|
||||
FakeAsyncCallbackHandler,
|
||||
FakeCallbackHandler,
|
||||
)
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -144,10 +144,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
||||
|
||||
model = ModelWithGenerate()
|
||||
chunks = [chunk for chunk in model.stream("anything")]
|
||||
assert chunks == [AIMessage(content="hello", id=AnyStr())]
|
||||
assert chunks == [_AnyIdAIMessage(content="hello")]
|
||||
|
||||
chunks = [chunk async for chunk in model.astream("anything")]
|
||||
assert chunks == [AIMessage(content="hello", id=AnyStr())]
|
||||
assert chunks == [_AnyIdAIMessage(content="hello")]
|
||||
|
||||
|
||||
async def test_astream_implementation_fallback_to_stream() -> None:
|
||||
@ -182,15 +182,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
|
||||
model = ModelWithSyncStream()
|
||||
chunks = [chunk for chunk in model.stream("anything")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="a", id=AnyStr()),
|
||||
AIMessageChunk(content="b", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="a"),
|
||||
_AnyIdAIMessageChunk(content="b"),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
assert type(model)._astream == BaseChatModel._astream
|
||||
astream_chunks = [chunk async for chunk in model.astream("anything")]
|
||||
assert astream_chunks == [
|
||||
AIMessageChunk(content="a", id=AnyStr()),
|
||||
AIMessageChunk(content="b", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="a"),
|
||||
_AnyIdAIMessageChunk(content="b"),
|
||||
]
|
||||
assert len({chunk.id for chunk in astream_chunks}) == 1
|
||||
|
||||
@ -227,8 +227,8 @@ async def test_astream_implementation_uses_astream() -> None:
|
||||
model = ModelWithAsyncStream()
|
||||
chunks = [chunk async for chunk in model.astream("anything")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="a", id=AnyStr()),
|
||||
AIMessageChunk(content="b", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="a"),
|
||||
_AnyIdAIMessageChunk(content="b"),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
|
@ -37,7 +37,6 @@ from langchain_core.language_models import (
|
||||
from langchain_core.load import dumpd, dumps
|
||||
from langchain_core.load.load import loads
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
@ -90,7 +89,7 @@ from langchain_core.tracers import (
|
||||
RunLogPatch,
|
||||
)
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
|
||||
@ -1825,7 +1824,7 @@ def test_prompt_with_chat_model(
|
||||
tracer = FakeTracer()
|
||||
assert chain.invoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == AIMessage(content="foo", id=AnyStr())
|
||||
) == _AnyIdAIMessage(content="foo")
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
@ -1850,8 +1849,8 @@ def test_prompt_with_chat_model(
|
||||
],
|
||||
dict(callbacks=[tracer]),
|
||||
) == [
|
||||
AIMessage(content="foo", id=AnyStr()),
|
||||
AIMessage(content="foo", id=AnyStr()),
|
||||
_AnyIdAIMessage(content="foo"),
|
||||
_AnyIdAIMessage(content="foo"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == [
|
||||
{"question": "What is your name?"},
|
||||
@ -1891,9 +1890,9 @@ def test_prompt_with_chat_model(
|
||||
assert [
|
||||
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
||||
] == [
|
||||
AIMessageChunk(content="f", id=AnyStr()),
|
||||
AIMessageChunk(content="o", id=AnyStr()),
|
||||
AIMessageChunk(content="o", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="f"),
|
||||
_AnyIdAIMessageChunk(content="o"),
|
||||
_AnyIdAIMessageChunk(content="o"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
@ -1931,7 +1930,7 @@ async def test_prompt_with_chat_model_async(
|
||||
tracer = FakeTracer()
|
||||
assert await chain.ainvoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == AIMessage(content="foo", id=AnyStr())
|
||||
) == _AnyIdAIMessage(content="foo")
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
@ -1956,8 +1955,8 @@ async def test_prompt_with_chat_model_async(
|
||||
],
|
||||
dict(callbacks=[tracer]),
|
||||
) == [
|
||||
AIMessage(content="foo", id=AnyStr()),
|
||||
AIMessage(content="foo", id=AnyStr()),
|
||||
_AnyIdAIMessage(content="foo"),
|
||||
_AnyIdAIMessage(content="foo"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == [
|
||||
{"question": "What is your name?"},
|
||||
@ -2000,9 +1999,9 @@ async def test_prompt_with_chat_model_async(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
)
|
||||
] == [
|
||||
AIMessageChunk(content="f", id=AnyStr()),
|
||||
AIMessageChunk(content="o", id=AnyStr()),
|
||||
AIMessageChunk(content="o", id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="f"),
|
||||
_AnyIdAIMessageChunk(content="o"),
|
||||
_AnyIdAIMessageChunk(content="o"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
@ -2669,7 +2668,7 @@ def test_prompt_with_chat_model_and_parser(
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
|
||||
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
|
||||
|
||||
assert tracer.runs == snapshot
|
||||
|
||||
@ -2804,7 +2803,7 @@ What is your name?"""
|
||||
),
|
||||
]
|
||||
)
|
||||
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
|
||||
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
|
||||
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 4
|
||||
@ -2850,7 +2849,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
||||
assert chain.invoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == {
|
||||
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
|
||||
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
|
||||
"llm": "i'm a textbot",
|
||||
}
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
@ -3060,7 +3059,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
|
||||
assert chain.invoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == {
|
||||
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
|
||||
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
|
||||
"llm": "i'm a textbot",
|
||||
"passthrough": ChatPromptValue(
|
||||
messages=[
|
||||
@ -3269,7 +3268,7 @@ async def test_map_astream() -> None:
|
||||
assert streamed_chunks[0] in [
|
||||
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||
{"llm": "i"},
|
||||
{"chat": AIMessageChunk(content="i", id=AnyStr())},
|
||||
{"chat": _AnyIdAIMessageChunk(content="i")},
|
||||
]
|
||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||
|
@ -30,7 +30,7 @@ from langchain_core.runnables import (
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.tools import tool
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
|
||||
|
||||
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
|
||||
@ -461,7 +461,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b"},
|
||||
"name": "my_model",
|
||||
@ -470,7 +470,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b"},
|
||||
"name": "my_model",
|
||||
@ -479,7 +479,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b"},
|
||||
"name": "my_model",
|
||||
@ -488,7 +488,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"output": AIMessageChunk(content="hello world!", id=AnyStr())},
|
||||
"data": {"output": _AnyIdAIMessageChunk(content="hello world!")},
|
||||
"event": "on_chat_model_end",
|
||||
"metadata": {"a": "b"},
|
||||
"name": "my_model",
|
||||
@ -526,7 +526,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -535,7 +535,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -544,7 +544,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -560,9 +560,7 @@ async def test_astream_events_from_model() -> None:
|
||||
[
|
||||
{
|
||||
"generation_info": None,
|
||||
"message": AIMessage(
|
||||
content="hello world!", id=AnyStr()
|
||||
),
|
||||
"message": _AnyIdAIMessage(content="hello world!"),
|
||||
"text": "hello world!",
|
||||
"type": "ChatGeneration",
|
||||
}
|
||||
@ -580,7 +578,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "i_dont_stream",
|
||||
@ -589,7 +587,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "i_dont_stream",
|
||||
@ -627,7 +625,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -636,7 +634,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -645,7 +643,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -661,9 +659,7 @@ async def test_astream_events_from_model() -> None:
|
||||
[
|
||||
{
|
||||
"generation_info": None,
|
||||
"message": AIMessage(
|
||||
content="hello world!", id=AnyStr()
|
||||
),
|
||||
"message": _AnyIdAIMessage(content="hello world!"),
|
||||
"text": "hello world!",
|
||||
"type": "ChatGeneration",
|
||||
}
|
||||
@ -681,7 +677,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "ai_dont_stream",
|
||||
@ -690,7 +686,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "ai_dont_stream",
|
||||
|
@ -50,7 +50,7 @@ from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.utils.aiter import aclosing
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
|
||||
|
||||
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
|
||||
@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -540,7 +540,7 @@ async def test_astream_events_from_model() -> None:
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"output": AIMessageChunk(content="hello world!", id=AnyStr()),
|
||||
"output": _AnyIdAIMessageChunk(content="hello world!"),
|
||||
},
|
||||
"event": "on_chat_model_end",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
@ -596,7 +596,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -605,7 +605,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -614,7 +614,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -625,7 +625,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
{
|
||||
"data": {
|
||||
"input": {"messages": [[HumanMessage(content="hello")]]},
|
||||
"output": AIMessage(content="hello world!", id=AnyStr()),
|
||||
"output": _AnyIdAIMessage(content="hello world!"),
|
||||
},
|
||||
"event": "on_chat_model_end",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
@ -635,7 +635,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "i_dont_stream",
|
||||
@ -644,7 +644,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "i_dont_stream",
|
||||
@ -682,7 +682,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -691,7 +691,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -700,7 +700,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
"name": "my_model",
|
||||
@ -711,7 +711,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
{
|
||||
"data": {
|
||||
"input": {"messages": [[HumanMessage(content="hello")]]},
|
||||
"output": AIMessage(content="hello world!", id=AnyStr()),
|
||||
"output": _AnyIdAIMessage(content="hello world!"),
|
||||
},
|
||||
"event": "on_chat_model_end",
|
||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||
@ -721,7 +721,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": ["my_model"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "ai_dont_stream",
|
||||
@ -730,7 +730,7 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
||||
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "ai_dont_stream",
|
||||
|
@ -1,6 +1,44 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
|
||||
|
||||
class AnyStr(str):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, str)
|
||||
|
||||
|
||||
# The code below creates version of pydantic models
|
||||
# that will work in unit tests with AnyStr as id field
|
||||
# Please note that the `id` field is assigned AFTER the model is created
|
||||
# to workaround an issue with pydantic ignoring the __eq__ method on
|
||||
# subclassed strings.
|
||||
|
||||
|
||||
def _AnyIdDocument(**kwargs: Any) -> Document:
|
||||
"""Create a document with an id field."""
|
||||
message = Document(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
|
||||
"""Create ai message with an any id field."""
|
||||
message = AIMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
|
||||
"""Create ai message with an any id field."""
|
||||
message = AIMessageChunk(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
|
||||
"""Create a human with an any id field."""
|
||||
message = HumanMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
from tests.unit_tests.stubs import AnyStr, _AnyIdDocument
|
||||
|
||||
|
||||
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
|
||||
@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
|
||||
|
||||
# Check sync version
|
||||
output = store.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo", id=AnyStr())]
|
||||
assert output == [_AnyIdDocument(page_content="foo")]
|
||||
|
||||
# Check async version
|
||||
output = await store.asimilarity_search("bar", k=2)
|
||||
assert output == [
|
||||
Document(page_content="bar", id=AnyStr()),
|
||||
Document(page_content="baz", id=AnyStr()),
|
||||
_AnyIdDocument(page_content="bar"),
|
||||
_AnyIdDocument(page_content="baz"),
|
||||
]
|
||||
|
||||
|
||||
@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
|
||||
# make sure we can k > docstore size
|
||||
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
|
||||
assert len(output) == len(texts)
|
||||
assert output[0] == Document(page_content="foo", id=AnyStr())
|
||||
assert output[1] == Document(page_content="foy", id=AnyStr())
|
||||
assert output[0] == _AnyIdDocument(page_content="foo")
|
||||
assert output[1] == _AnyIdDocument(page_content="foy")
|
||||
|
||||
# Check async version
|
||||
output = await docsearch.amax_marginal_relevance_search(
|
||||
"foo", k=10, lambda_mult=0.1
|
||||
)
|
||||
assert len(output) == len(texts)
|
||||
assert output[0] == Document(page_content="foo", id=AnyStr())
|
||||
assert output[1] == Document(page_content="foy", id=AnyStr())
|
||||
assert output[0] == _AnyIdDocument(page_content="foo")
|
||||
assert output[1] == _AnyIdDocument(page_content="foy")
|
||||
|
||||
|
||||
async def test_inmemory_dump_load(tmp_path: Path) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user