mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 07:07:34 +00:00
core[patch]: Update unit tests with a workaround for using AnyID in pydantic 2 (#24892)
Pydantic 2 ignores __eq__ overload for subclasses of strings.
This commit is contained in:
parent
8461934c2b
commit
5099a9c9b4
@ -7,9 +7,13 @@ from uuid import UUID
|
|||||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||||
from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
|
from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||||
from langchain_core.messages.human import HumanMessage
|
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||||
from tests.unit_tests.stubs import AnyStr
|
from tests.unit_tests.stubs import (
|
||||||
|
AnyStr,
|
||||||
|
_AnyIdAIMessage,
|
||||||
|
_AnyIdAIMessageChunk,
|
||||||
|
_AnyIdHumanMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_generic_fake_chat_model_invoke() -> None:
|
def test_generic_fake_chat_model_invoke() -> None:
|
||||||
@ -17,11 +21,11 @@ def test_generic_fake_chat_model_invoke() -> None:
|
|||||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
response = model.invoke("meow")
|
response = model.invoke("meow")
|
||||||
assert response == AIMessage(content="hello", id=AnyStr())
|
assert response == _AnyIdAIMessage(content="hello")
|
||||||
response = model.invoke("kitty")
|
response = model.invoke("kitty")
|
||||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
assert response == _AnyIdAIMessage(content="goodbye")
|
||||||
response = model.invoke("meow")
|
response = model.invoke("meow")
|
||||||
assert response == AIMessage(content="hello", id=AnyStr())
|
assert response == _AnyIdAIMessage(content="hello")
|
||||||
|
|
||||||
|
|
||||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||||
@ -29,11 +33,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
|
|||||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
response = await model.ainvoke("meow")
|
response = await model.ainvoke("meow")
|
||||||
assert response == AIMessage(content="hello", id=AnyStr())
|
assert response == _AnyIdAIMessage(content="hello")
|
||||||
response = await model.ainvoke("kitty")
|
response = await model.ainvoke("kitty")
|
||||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
assert response == _AnyIdAIMessage(content="goodbye")
|
||||||
response = await model.ainvoke("meow")
|
response = await model.ainvoke("meow")
|
||||||
assert response == AIMessage(content="hello", id=AnyStr())
|
assert response == _AnyIdAIMessage(content="hello")
|
||||||
|
|
||||||
|
|
||||||
async def test_generic_fake_chat_model_stream() -> None:
|
async def test_generic_fake_chat_model_stream() -> None:
|
||||||
@ -46,17 +50,17 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
chunks = [chunk async for chunk in model.astream("meow")]
|
chunks = [chunk async for chunk in model.astream("meow")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
AIMessageChunk(content="hello", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="hello"),
|
||||||
AIMessageChunk(content=" ", id=AnyStr()),
|
_AnyIdAIMessageChunk(content=" "),
|
||||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
chunks = [chunk for chunk in model.stream("meow")]
|
chunks = [chunk for chunk in model.stream("meow")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
AIMessageChunk(content="hello", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="hello"),
|
||||||
AIMessageChunk(content=" ", id=AnyStr()),
|
_AnyIdAIMessageChunk(content=" "),
|
||||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
@ -141,9 +145,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
|
|||||||
]
|
]
|
||||||
final = log_patches[-1]
|
final = log_patches[-1]
|
||||||
assert final.state["streamed_output"] == [
|
assert final.state["streamed_output"] == [
|
||||||
AIMessageChunk(content="hello", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="hello"),
|
||||||
AIMessageChunk(content=" ", id=AnyStr()),
|
_AnyIdAIMessageChunk(content=" "),
|
||||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
|
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
|
||||||
|
|
||||||
@ -192,9 +196,9 @@ async def test_callback_handlers() -> None:
|
|||||||
# New model
|
# New model
|
||||||
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
||||||
assert results == [
|
assert results == [
|
||||||
AIMessageChunk(content="hello", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="hello"),
|
||||||
AIMessageChunk(content=" ", id=AnyStr()),
|
_AnyIdAIMessageChunk(content=" "),
|
||||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
assert tokens == ["hello", " ", "goodbye"]
|
assert tokens == ["hello", " ", "goodbye"]
|
||||||
assert len({chunk.id for chunk in results}) == 1
|
assert len({chunk.id for chunk in results}) == 1
|
||||||
@ -203,8 +207,6 @@ async def test_callback_handlers() -> None:
|
|||||||
def test_chat_model_inputs() -> None:
|
def test_chat_model_inputs() -> None:
|
||||||
fake = ParrotFakeChatModel()
|
fake = ParrotFakeChatModel()
|
||||||
|
|
||||||
assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr())
|
assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello")
|
||||||
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr())
|
assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah")
|
||||||
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(
|
assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah")
|
||||||
content="blah", id=AnyStr()
|
|
||||||
)
|
|
||||||
|
@ -24,7 +24,7 @@ from tests.unit_tests.fake.callbacks import (
|
|||||||
FakeAsyncCallbackHandler,
|
FakeAsyncCallbackHandler,
|
||||||
FakeCallbackHandler,
|
FakeCallbackHandler,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.stubs import AnyStr
|
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -144,10 +144,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
|||||||
|
|
||||||
model = ModelWithGenerate()
|
model = ModelWithGenerate()
|
||||||
chunks = [chunk for chunk in model.stream("anything")]
|
chunks = [chunk for chunk in model.stream("anything")]
|
||||||
assert chunks == [AIMessage(content="hello", id=AnyStr())]
|
assert chunks == [_AnyIdAIMessage(content="hello")]
|
||||||
|
|
||||||
chunks = [chunk async for chunk in model.astream("anything")]
|
chunks = [chunk async for chunk in model.astream("anything")]
|
||||||
assert chunks == [AIMessage(content="hello", id=AnyStr())]
|
assert chunks == [_AnyIdAIMessage(content="hello")]
|
||||||
|
|
||||||
|
|
||||||
async def test_astream_implementation_fallback_to_stream() -> None:
|
async def test_astream_implementation_fallback_to_stream() -> None:
|
||||||
@ -182,15 +182,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
|
|||||||
model = ModelWithSyncStream()
|
model = ModelWithSyncStream()
|
||||||
chunks = [chunk for chunk in model.stream("anything")]
|
chunks = [chunk for chunk in model.stream("anything")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
AIMessageChunk(content="a", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="a"),
|
||||||
AIMessageChunk(content="b", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="b"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
assert type(model)._astream == BaseChatModel._astream
|
assert type(model)._astream == BaseChatModel._astream
|
||||||
astream_chunks = [chunk async for chunk in model.astream("anything")]
|
astream_chunks = [chunk async for chunk in model.astream("anything")]
|
||||||
assert astream_chunks == [
|
assert astream_chunks == [
|
||||||
AIMessageChunk(content="a", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="a"),
|
||||||
AIMessageChunk(content="b", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="b"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in astream_chunks}) == 1
|
assert len({chunk.id for chunk in astream_chunks}) == 1
|
||||||
|
|
||||||
@ -227,8 +227,8 @@ async def test_astream_implementation_uses_astream() -> None:
|
|||||||
model = ModelWithAsyncStream()
|
model = ModelWithAsyncStream()
|
||||||
chunks = [chunk async for chunk in model.astream("anything")]
|
chunks = [chunk async for chunk in model.astream("anything")]
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
AIMessageChunk(content="a", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="a"),
|
||||||
AIMessageChunk(content="b", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="b"),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
|
@ -37,7 +37,6 @@ from langchain_core.language_models import (
|
|||||||
from langchain_core.load import dumpd, dumps
|
from langchain_core.load import dumpd, dumps
|
||||||
from langchain_core.load.load import loads
|
from langchain_core.load.load import loads
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
@ -90,7 +89,7 @@ from langchain_core.tracers import (
|
|||||||
RunLogPatch,
|
RunLogPatch,
|
||||||
)
|
)
|
||||||
from langchain_core.tracers.context import collect_runs
|
from langchain_core.tracers.context import collect_runs
|
||||||
from tests.unit_tests.stubs import AnyStr
|
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||||
|
|
||||||
|
|
||||||
class FakeTracer(BaseTracer):
|
class FakeTracer(BaseTracer):
|
||||||
@ -1825,7 +1824,7 @@ def test_prompt_with_chat_model(
|
|||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
assert chain.invoke(
|
assert chain.invoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == AIMessage(content="foo", id=AnyStr())
|
) == _AnyIdAIMessage(content="foo")
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
messages=[
|
messages=[
|
||||||
@ -1850,8 +1849,8 @@ def test_prompt_with_chat_model(
|
|||||||
],
|
],
|
||||||
dict(callbacks=[tracer]),
|
dict(callbacks=[tracer]),
|
||||||
) == [
|
) == [
|
||||||
AIMessage(content="foo", id=AnyStr()),
|
_AnyIdAIMessage(content="foo"),
|
||||||
AIMessage(content="foo", id=AnyStr()),
|
_AnyIdAIMessage(content="foo"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == [
|
assert prompt_spy.call_args.args[1] == [
|
||||||
{"question": "What is your name?"},
|
{"question": "What is your name?"},
|
||||||
@ -1891,9 +1890,9 @@ def test_prompt_with_chat_model(
|
|||||||
assert [
|
assert [
|
||||||
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
||||||
] == [
|
] == [
|
||||||
AIMessageChunk(content="f", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="f"),
|
||||||
AIMessageChunk(content="o", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="o"),
|
||||||
AIMessageChunk(content="o", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="o"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
@ -1931,7 +1930,7 @@ async def test_prompt_with_chat_model_async(
|
|||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
assert await chain.ainvoke(
|
assert await chain.ainvoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == AIMessage(content="foo", id=AnyStr())
|
) == _AnyIdAIMessage(content="foo")
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
messages=[
|
messages=[
|
||||||
@ -1956,8 +1955,8 @@ async def test_prompt_with_chat_model_async(
|
|||||||
],
|
],
|
||||||
dict(callbacks=[tracer]),
|
dict(callbacks=[tracer]),
|
||||||
) == [
|
) == [
|
||||||
AIMessage(content="foo", id=AnyStr()),
|
_AnyIdAIMessage(content="foo"),
|
||||||
AIMessage(content="foo", id=AnyStr()),
|
_AnyIdAIMessage(content="foo"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == [
|
assert prompt_spy.call_args.args[1] == [
|
||||||
{"question": "What is your name?"},
|
{"question": "What is your name?"},
|
||||||
@ -2000,9 +1999,9 @@ async def test_prompt_with_chat_model_async(
|
|||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
)
|
)
|
||||||
] == [
|
] == [
|
||||||
AIMessageChunk(content="f", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="f"),
|
||||||
AIMessageChunk(content="o", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="o"),
|
||||||
AIMessageChunk(content="o", id=AnyStr()),
|
_AnyIdAIMessageChunk(content="o"),
|
||||||
]
|
]
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
@ -2669,7 +2668,7 @@ def test_prompt_with_chat_model_and_parser(
|
|||||||
HumanMessage(content="What is your name?"),
|
HumanMessage(content="What is your name?"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
|
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
|
||||||
|
|
||||||
assert tracer.runs == snapshot
|
assert tracer.runs == snapshot
|
||||||
|
|
||||||
@ -2804,7 +2803,7 @@ What is your name?"""
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
|
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
|
||||||
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||||
assert len(parent_run.child_runs) == 4
|
assert len(parent_run.child_runs) == 4
|
||||||
@ -2850,7 +2849,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
|||||||
assert chain.invoke(
|
assert chain.invoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == {
|
) == {
|
||||||
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
|
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
|
||||||
"llm": "i'm a textbot",
|
"llm": "i'm a textbot",
|
||||||
}
|
}
|
||||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
@ -3060,7 +3059,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
|
|||||||
assert chain.invoke(
|
assert chain.invoke(
|
||||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
) == {
|
) == {
|
||||||
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
|
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
|
||||||
"llm": "i'm a textbot",
|
"llm": "i'm a textbot",
|
||||||
"passthrough": ChatPromptValue(
|
"passthrough": ChatPromptValue(
|
||||||
messages=[
|
messages=[
|
||||||
@ -3269,7 +3268,7 @@ async def test_map_astream() -> None:
|
|||||||
assert streamed_chunks[0] in [
|
assert streamed_chunks[0] in [
|
||||||
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||||
{"llm": "i"},
|
{"llm": "i"},
|
||||||
{"chat": AIMessageChunk(content="i", id=AnyStr())},
|
{"chat": _AnyIdAIMessageChunk(content="i")},
|
||||||
]
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
|
@ -30,7 +30,7 @@ from langchain_core.runnables import (
|
|||||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from tests.unit_tests.stubs import AnyStr
|
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||||
|
|
||||||
|
|
||||||
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
|
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
|
||||||
@ -461,7 +461,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -470,7 +470,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -479,7 +479,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -488,7 +488,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": AIMessageChunk(content="hello world!", id=AnyStr())},
|
"data": {"output": _AnyIdAIMessageChunk(content="hello world!")},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {"a": "b"},
|
"metadata": {"a": "b"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -526,7 +526,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -535,7 +535,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -544,7 +544,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -560,9 +560,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"generation_info": None,
|
"generation_info": None,
|
||||||
"message": AIMessage(
|
"message": _AnyIdAIMessage(content="hello world!"),
|
||||||
content="hello world!", id=AnyStr()
|
|
||||||
),
|
|
||||||
"text": "hello world!",
|
"text": "hello world!",
|
||||||
"type": "ChatGeneration",
|
"type": "ChatGeneration",
|
||||||
}
|
}
|
||||||
@ -580,7 +578,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -589,7 +587,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -627,7 +625,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -636,7 +634,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -645,7 +643,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -661,9 +659,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"generation_info": None,
|
"generation_info": None,
|
||||||
"message": AIMessage(
|
"message": _AnyIdAIMessage(content="hello world!"),
|
||||||
content="hello world!", id=AnyStr()
|
|
||||||
),
|
|
||||||
"text": "hello world!",
|
"text": "hello world!",
|
||||||
"type": "ChatGeneration",
|
"type": "ChatGeneration",
|
||||||
}
|
}
|
||||||
@ -681,7 +677,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
@ -690,7 +686,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
|
@ -50,7 +50,7 @@ from langchain_core.runnables.schema import StreamEvent
|
|||||||
from langchain_core.runnables.utils import Input, Output
|
from langchain_core.runnables.utils import Input, Output
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_core.utils.aiter import aclosing
|
from langchain_core.utils.aiter import aclosing
|
||||||
from tests.unit_tests.stubs import AnyStr
|
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||||
|
|
||||||
|
|
||||||
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
|
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
|
||||||
@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -540,7 +540,7 @@ async def test_astream_events_from_model() -> None:
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"output": AIMessageChunk(content="hello world!", id=AnyStr()),
|
"output": _AnyIdAIMessageChunk(content="hello world!"),
|
||||||
},
|
},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
@ -596,7 +596,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -605,7 +605,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -614,7 +614,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -625,7 +625,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"input": {"messages": [[HumanMessage(content="hello")]]},
|
"input": {"messages": [[HumanMessage(content="hello")]]},
|
||||||
"output": AIMessage(content="hello world!", id=AnyStr()),
|
"output": _AnyIdAIMessage(content="hello world!"),
|
||||||
},
|
},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
@ -635,7 +635,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -644,7 +644,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "i_dont_stream",
|
"name": "i_dont_stream",
|
||||||
@ -682,7 +682,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -691,7 +691,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -700,7 +700,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
|
||||||
"event": "on_chat_model_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
"name": "my_model",
|
"name": "my_model",
|
||||||
@ -711,7 +711,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"input": {"messages": [[HumanMessage(content="hello")]]},
|
"input": {"messages": [[HumanMessage(content="hello")]]},
|
||||||
"output": AIMessage(content="hello world!", id=AnyStr()),
|
"output": _AnyIdAIMessage(content="hello world!"),
|
||||||
},
|
},
|
||||||
"event": "on_chat_model_end",
|
"event": "on_chat_model_end",
|
||||||
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
|
||||||
@ -721,7 +721,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": ["my_model"],
|
"tags": ["my_model"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
@ -730,7 +730,7 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
|
"data": {"output": _AnyIdAIMessage(content="hello world!")},
|
||||||
"event": "on_chain_end",
|
"event": "on_chain_end",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ai_dont_stream",
|
"name": "ai_dont_stream",
|
||||||
|
@ -1,6 +1,44 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
class AnyStr(str):
|
class AnyStr(str):
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
return isinstance(other, str)
|
return isinstance(other, str)
|
||||||
|
|
||||||
|
|
||||||
|
# The code below creates version of pydantic models
|
||||||
|
# that will work in unit tests with AnyStr as id field
|
||||||
|
# Please note that the `id` field is assigned AFTER the model is created
|
||||||
|
# to workaround an issue with pydantic ignoring the __eq__ method on
|
||||||
|
# subclassed strings.
|
||||||
|
|
||||||
|
|
||||||
|
def _AnyIdDocument(**kwargs: Any) -> Document:
|
||||||
|
"""Create a document with an id field."""
|
||||||
|
message = Document(**kwargs)
|
||||||
|
message.id = AnyStr()
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
|
||||||
|
"""Create ai message with an any id field."""
|
||||||
|
message = AIMessage(**kwargs)
|
||||||
|
message.id = AnyStr()
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
|
||||||
|
"""Create ai message with an any id field."""
|
||||||
|
message = AIMessageChunk(**kwargs)
|
||||||
|
message.id = AnyStr()
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
|
||||||
|
"""Create a human with an any id field."""
|
||||||
|
message = HumanMessage(**kwargs)
|
||||||
|
message.id = AnyStr()
|
||||||
|
return message
|
||||||
|
@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
|
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
|
||||||
from langchain_core.vectorstores import InMemoryVectorStore
|
from langchain_core.vectorstores import InMemoryVectorStore
|
||||||
from tests.unit_tests.stubs import AnyStr
|
from tests.unit_tests.stubs import AnyStr, _AnyIdDocument
|
||||||
|
|
||||||
|
|
||||||
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
|
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
|
||||||
@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
|
|||||||
|
|
||||||
# Check sync version
|
# Check sync version
|
||||||
output = store.similarity_search("foo", k=1)
|
output = store.similarity_search("foo", k=1)
|
||||||
assert output == [Document(page_content="foo", id=AnyStr())]
|
assert output == [_AnyIdDocument(page_content="foo")]
|
||||||
|
|
||||||
# Check async version
|
# Check async version
|
||||||
output = await store.asimilarity_search("bar", k=2)
|
output = await store.asimilarity_search("bar", k=2)
|
||||||
assert output == [
|
assert output == [
|
||||||
Document(page_content="bar", id=AnyStr()),
|
_AnyIdDocument(page_content="bar"),
|
||||||
Document(page_content="baz", id=AnyStr()),
|
_AnyIdDocument(page_content="baz"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
|
|||||||
# make sure we can k > docstore size
|
# make sure we can k > docstore size
|
||||||
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
|
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
|
||||||
assert len(output) == len(texts)
|
assert len(output) == len(texts)
|
||||||
assert output[0] == Document(page_content="foo", id=AnyStr())
|
assert output[0] == _AnyIdDocument(page_content="foo")
|
||||||
assert output[1] == Document(page_content="foy", id=AnyStr())
|
assert output[1] == _AnyIdDocument(page_content="foy")
|
||||||
|
|
||||||
# Check async version
|
# Check async version
|
||||||
output = await docsearch.amax_marginal_relevance_search(
|
output = await docsearch.amax_marginal_relevance_search(
|
||||||
"foo", k=10, lambda_mult=0.1
|
"foo", k=10, lambda_mult=0.1
|
||||||
)
|
)
|
||||||
assert len(output) == len(texts)
|
assert len(output) == len(texts)
|
||||||
assert output[0] == Document(page_content="foo", id=AnyStr())
|
assert output[0] == _AnyIdDocument(page_content="foo")
|
||||||
assert output[1] == Document(page_content="foy", id=AnyStr())
|
assert output[1] == _AnyIdDocument(page_content="foy")
|
||||||
|
|
||||||
|
|
||||||
async def test_inmemory_dump_load(tmp_path: Path) -> None:
|
async def test_inmemory_dump_load(tmp_path: Path) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user