core: Assign missing message ids in BaseChatModel (#19863)
- This ensures ids are stable across streamed chunks
- Multiple messages in a batch call get separate ids
- Also fixes ids being dropped when combining message chunks
This commit is contained in:
parent
e830a4e731
commit
2ae6dcdf01
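
What the change looks like in practice, as a minimal sketch (it mirrors the unit tests added in this commit; GenericFakeChatModel ships with langchain_core):

    from itertools import cycle

    from langchain_core.language_models import GenericFakeChatModel
    from langchain_core.messages import AIMessage

    model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world")]))

    # Every chunk of one streamed response now carries the same stable id...
    chunks = list(model.stream("meow"))
    assert len({chunk.id for chunk in chunks}) == 1

    # ...and the id survives when chunks are merged back into a single message.
    merged = chunks[0]
    for chunk in chunks[1:]:
        merged = merged + chunk
    assert merged.id == chunks[0].id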
@@ -224,6 +224,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                     run_manager.on_llm_new_token(
                         cast(str, chunk.message.content), chunk=chunk
                     )
+                if chunk.message.id is None:
+                    chunk.message.id = f"run-{run_manager.run_id}"
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
             if generation is None:
@@ -294,6 +296,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                     await run_manager.on_llm_new_token(
                         cast(str, chunk.message.content), chunk=chunk
                     )
+                if chunk.message.id is None:
+                    chunk.message.id = f"run-{run_manager.run_id}"
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
             if generation is None:
@@ -607,6 +611,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             chunks: List[ChatGenerationChunk] = []
             for chunk in self._stream(messages, stop=stop, **kwargs):
                 if run_manager:
+                    if chunk.message.id is None:
+                        chunk.message.id = f"run-{run_manager.run_id}"
                     run_manager.on_llm_new_token(
                         cast(str, chunk.message.content), chunk=chunk
                     )
@@ -622,7 +628,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             result = self._generate(messages, stop=stop, **kwargs)

             # Add response metadata to each generation
-            for generation in result.generations:
+            for idx, generation in enumerate(result.generations):
+                if run_manager and generation.message.id is None:
+                    generation.message.id = f"run-{run_manager.run_id}-{idx}"
                 generation.message.response_metadata = _gen_info_and_msg_metadata(
                     generation
                 )
@@ -684,6 +692,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             chunks: List[ChatGenerationChunk] = []
             async for chunk in self._astream(messages, stop=stop, **kwargs):
                 if run_manager:
+                    if chunk.message.id is None:
+                        chunk.message.id = f"run-{run_manager.run_id}"
                     await run_manager.on_llm_new_token(
                         cast(str, chunk.message.content), chunk=chunk
                     )
@@ -699,7 +709,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             result = await self._agenerate(messages, stop=stop, **kwargs)

             # Add response metadata to each generation
-            for generation in result.generations:
+            for idx, generation in enumerate(result.generations):
+                if run_manager and generation.message.id is None:
+                    generation.message.id = f"run-{run_manager.run_id}-{idx}"
                 generation.message.response_metadata = _gen_info_and_msg_metadata(
                     generation
                 )
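A note on the fallback ids added above: when neither the provider nor the caller supplies a message id, BaseChatModel now stamps one derived from the callback run, f"run-{run_manager.run_id}" for streamed chunks and f"run-{run_manager.run_id}-{idx}" for the idx-th generation of a generate result, so each message in a batch call gets its own id. A sketch, reusing the fake model from the example above:

    first, second = model.batch(["meow", "woof"])
    assert first.id != second.id  # each batch input runs under its own run_id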
@@ -223,7 +223,9 @@ class GenericFakeChatModel(BaseChatModel):
         content_chunks = cast(List[str], re.split(r"(\s)", content))

         for token in content_chunks:
-            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
+            chunk = ChatGenerationChunk(
+                message=AIMessageChunk(content=token, id=message.id)
+            )
             if run_manager:
                 run_manager.on_llm_new_token(token, chunk=chunk)
             yield chunk
@@ -240,6 +242,7 @@ class GenericFakeChatModel(BaseChatModel):
                 for fvalue_chunk in fvalue_chunks:
                     chunk = ChatGenerationChunk(
                         message=AIMessageChunk(
+                            id=message.id,
                             content="",
                             additional_kwargs={
                                 "function_call": {fkey: fvalue_chunk}
@@ -255,6 +258,7 @@ class GenericFakeChatModel(BaseChatModel):
             else:
                 chunk = ChatGenerationChunk(
                     message=AIMessageChunk(
+                        id=message.id,
                         content="",
                         additional_kwargs={"function_call": {fkey: fvalue}},
                     )
@@ -268,7 +272,7 @@ class GenericFakeChatModel(BaseChatModel):
         else:
             chunk = ChatGenerationChunk(
                 message=AIMessageChunk(
-                    content="", additional_kwargs={key: value}
+                    id=message.id, content="", additional_kwargs={key: value}
                 )
             )
             if run_manager:
@@ -971,4 +971,10 @@ _JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
         "tool",
         "ToolMessageChunk",
     ),
+    ("langchain_core", "prompts", "image", "ImagePromptTemplate"): (
+        "langchain_core",
+        "prompts",
+        "image",
+        "ImagePromptTemplate",
+    ),
 }
@@ -56,6 +56,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
                 response_metadata=merge_dicts(
                     self.response_metadata, other.response_metadata
                 ),
+                id=self.id,
             )

         return super().__add__(other)
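With this, adding chunks keeps the left-hand chunk's id, matching the message-chunk tests further down:

    from langchain_core.messages import AIMessageChunk

    merged = AIMessageChunk(content="I am", id="ai3") + AIMessageChunk(content=" indeed.")
    assert merged == AIMessageChunk(content="I am indeed.", id="ai3")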
@@ -34,6 +34,8 @@ class BaseMessage(Serializable):
     name: Optional[str] = None

+    id: Optional[str] = None
+    """An optional unique identifier for the message. This should ideally be
+    provided by the provider/model which created the message."""

     class Config:
         extra = Extra.allow
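The field defaults to None, so existing messages are unchanged until a provider or the run-scoped fallback above fills it in:

    from langchain_core.messages import AIMessage

    assert AIMessage(content="hi").id is None  # unset until a model or run assigns one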
@@ -54,6 +54,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
                 response_metadata=merge_dicts(
                     self.response_metadata, other.response_metadata
                 ),
+                id=self.id,
             )
         elif isinstance(other, BaseMessageChunk):
             return self.__class__(
@@ -65,6 +66,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
                 response_metadata=merge_dicts(
                     self.response_metadata, other.response_metadata
                 ),
+                id=self.id,
             )
         else:
             return super().__add__(other)
@@ -54,6 +54,7 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
                 response_metadata=merge_dicts(
                     self.response_metadata, other.response_metadata
                 ),
+                id=self.id,
             )

         return super().__add__(other)
@@ -54,6 +54,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
                 response_metadata=merge_dicts(
                     self.response_metadata, other.response_metadata
                 ),
+                id=self.id,
             )

         return super().__add__(other)
@@ -116,7 +116,9 @@ def node_data_str(node: Node) -> str:
     return data if not data.startswith("Runnable") else data[8:]


-def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
+def node_data_json(
+    node: Node, *, with_schemas: bool = False
+) -> Dict[str, Union[str, Dict[str, Any]]]:
     from langchain_core.load.serializable import to_json_not_implemented
     from langchain_core.runnables.base import Runnable, RunnableSerializable

@@ -137,10 +139,17 @@ def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
             },
         }
     elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
-        return {
-            "type": "schema",
-            "data": node.data.schema(),
-        }
+        return (
+            {
+                "type": "schema",
+                "data": node.data.schema(),
+            }
+            if with_schemas
+            else {
+                "type": "schema",
+                "data": node_data_str(node),
+            }
+        )
     else:
         return {
             "type": "unknown",
@@ -156,7 +165,7 @@ class Graph:
     edges: List[Edge] = field(default_factory=list)
     branches: Optional[Dict[str, List[Branch]]] = field(default_factory=dict)

-    def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
+    def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
         """Convert the graph to a JSON-serializable format."""
         stable_node_ids = {
             node.id: i if is_uuid(node.id) else node.id
@@ -165,7 +174,10 @@ class Graph:

         return {
             "nodes": [
-                {"id": stable_node_ids[node.id], **node_data_json(node)}
+                {
+                    "id": stable_node_ids[node.id],
+                    **node_data_json(node, with_schemas=with_schemas),
+                }
                 for node in self.nodes.values()
             ],
             "edges": [
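A usage sketch for the new flag (prompt, fake_llm, and list_parser stand in for arbitrary runnables, as in the graph tests below):

    graph = (prompt | fake_llm | list_parser).get_graph()
    compact = graph.to_json()                   # schema nodes as bare names, e.g. "PromptInput"
    verbose = graph.to_json(with_schemas=True)  # schema nodes carry full JSON schemas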
@@ -8,6 +8,7 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.messages.human import HumanMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from tests.unit_tests.stubs import AnyStr


 def test_generic_fake_chat_model_invoke() -> None:
@@ -15,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = model.invoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())


 async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -27,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = await model.ainvoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())


 async def test_generic_fake_chat_model_stream() -> None:
@@ -44,17 +45,19 @@ async def test_generic_fake_chat_model_stream() -> None:
     model = GenericFakeChatModel(messages=infinite_cycle)
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
+    assert len({chunk.id for chunk in chunks}) == 1

     chunks = [chunk for chunk in model.stream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
+    assert len({chunk.id for chunk in chunks}) == 1

     # Test streaming of additional kwargs.
     # Relying on insertion order of the additional kwargs dict
@@ -62,9 +65,10 @@ async def test_generic_fake_chat_model_stream() -> None:
     model = GenericFakeChatModel(messages=cycle([message]))
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="", additional_kwargs={"foo": 42}),
-        AIMessageChunk(content="", additional_kwargs={"bar": 24}),
+        AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
+        AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
     ]
+    assert len({chunk.id for chunk in chunks}) == 1

     message = AIMessage(
         content="",
@@ -81,24 +85,31 @@ async def test_generic_fake_chat_model_stream() -> None:

     assert chunks == [
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"name": "move_file"}}
+            content="",
+            additional_kwargs={"function_call": {"name": "move_file"}},
+            id=AnyStr(),
         ),
         AIMessageChunk(
             content="",
             additional_kwargs={
-                "function_call": {"arguments": '{\n "source_path": "foo"'}
+                "function_call": {"arguments": '{\n "source_path": "foo"'},
             },
+            id=AnyStr(),
         ),
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"arguments": ","}}
+            content="",
+            additional_kwargs={"function_call": {"arguments": ","}},
+            id=AnyStr(),
         ),
         AIMessageChunk(
             content="",
             additional_kwargs={
-                "function_call": {"arguments": '\n "destination_path": "bar"\n}'}
+                "function_call": {"arguments": '\n "destination_path": "bar"\n}'},
             },
+            id=AnyStr(),
         ),
     ]
+    assert len({chunk.id for chunk in chunks}) == 1

     accumulate_chunks = None
     for chunk in chunks:
@@ -116,6 +127,7 @@ async def test_generic_fake_chat_model_stream() -> None:
                 'destination_path": "bar"\n}',
             }
         },
+        id=chunks[0].id,
     )


@@ -128,10 +140,11 @@ async def test_generic_fake_chat_model_astream_log() -> None:
     ]
     final = log_patches[-1]
     assert final.state["streamed_output"] == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
+    assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1


 async def test_callback_handlers() -> None:
@@ -178,16 +191,19 @@ async def test_callback_handlers() -> None:
     # New model
     results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
     assert results == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
     assert tokens == ["hello", " ", "goodbye"]
+    assert len({chunk.id for chunk in results}) == 1


 def test_chat_model_inputs() -> None:
     fake = ParrotFakeChatModel()

-    assert fake.invoke("hello") == HumanMessage(content="hello")
-    assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
-    assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
+    assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr())
+    assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr())
+    assert fake.invoke([AIMessage(content="blah")]) == AIMessage(
+        content="blah", id=AnyStr()
+    )
[Two file diffs omitted here: one suppressed because it is too large, another because one or more lines are too long.]
@@ -33,6 +33,55 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
     sequence = prompt | fake_llm | list_parser
     graph = sequence.get_graph()
-    assert graph.to_json() == {
+    assert graph.to_json() == {
+        "nodes": [
+            {
+                "id": 0,
+                "type": "schema",
+                "data": "PromptInput",
+            },
+            {
+                "id": 1,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
+                    "name": "PromptTemplate",
+                },
+            },
+            {
+                "id": 2,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain_core", "language_models", "fake", "FakeListLLM"],
+                    "name": "FakeListLLM",
+                },
+            },
+            {
+                "id": 3,
+                "type": "runnable",
+                "data": {
+                    "id": [
+                        "langchain",
+                        "output_parsers",
+                        "list",
+                        "CommaSeparatedListOutputParser",
+                    ],
+                    "name": "CommaSeparatedListOutputParser",
+                },
+            },
+            {
+                "id": 4,
+                "type": "schema",
+                "data": "CommaSeparatedListOutputParserOutput",
+            },
+        ],
+        "edges": [
+            {"source": 0, "target": 1},
+            {"source": 1, "target": 2},
+            {"source": 3, "target": 4},
+            {"source": 2, "target": 3},
+        ],
+    }
+    assert graph.to_json(with_schemas=True) == {
         "nodes": [
             {
                 "id": 0,
@@ -76,9 +125,9 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
                 "id": 4,
                 "type": "schema",
                 "data": {
+                    "items": {"type": "string"},
                     "title": "CommaSeparatedListOutputParserOutput",
                     "type": "array",
-                    "items": {"type": "string"},
                 },
             },
         ],
@@ -115,7 +164,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
         }
     )
     graph = sequence.get_graph()
-    assert graph.to_json() == {
+    assert graph.to_json(with_schemas=True) == {
         "nodes": [
             {
                 "id": 0,
@@ -484,5 +533,97 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
             {"source": 2, "target": 3},
         ],
     }
+    assert graph.to_json() == {
+        "nodes": [
+            {
+                "id": 0,
+                "type": "schema",
+                "data": "PromptInput",
+            },
+            {
+                "id": 1,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain", "prompts", "prompt", "PromptTemplate"],
+                    "name": "PromptTemplate",
+                },
+            },
+            {
+                "id": 2,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain_core", "language_models", "fake", "FakeListLLM"],
+                    "name": "FakeListLLM",
+                },
+            },
+            {
+                "id": 3,
+                "type": "schema",
+                "data": "Parallel<as_list,as_str>Input",
+            },
+            {
+                "id": 4,
+                "type": "schema",
+                "data": "Parallel<as_list,as_str>Output",
+            },
+            {
+                "id": 5,
+                "type": "runnable",
+                "data": {
+                    "id": [
+                        "langchain",
+                        "output_parsers",
+                        "list",
+                        "CommaSeparatedListOutputParser",
+                    ],
+                    "name": "CommaSeparatedListOutputParser",
+                },
+            },
+            {
+                "id": 6,
+                "type": "schema",
+                "data": "conditional_str_parser_input",
+            },
+            {
+                "id": 7,
+                "type": "schema",
+                "data": "conditional_str_parser_output",
+            },
+            {
+                "id": 8,
+                "type": "runnable",
+                "data": {
+                    "id": ["langchain", "schema", "output_parser", "StrOutputParser"],
+                    "name": "StrOutputParser",
+                },
+            },
+            {
+                "id": 9,
+                "type": "runnable",
+                "data": {
+                    "id": [
+                        "langchain_core",
+                        "output_parsers",
+                        "xml",
+                        "XMLOutputParser",
+                    ],
+                    "name": "XMLOutputParser",
+                },
+            },
+        ],
+        "edges": [
+            {"source": 0, "target": 1},
+            {"source": 1, "target": 2},
+            {"source": 3, "target": 5},
+            {"source": 5, "target": 4},
+            {"source": 6, "target": 8},
+            {"source": 8, "target": 7},
+            {"source": 6, "target": 9},
+            {"source": 9, "target": 7},
+            {"source": 3, "target": 6},
+            {"source": 7, "target": 4},
+            {"source": 2, "target": 3},
+        ],
+    }
     assert graph.draw_ascii() == snapshot(name="ascii")
     assert graph.draw_mermaid() == snapshot(name="mermaid")
@@ -41,6 +41,7 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.messages.base import BaseMessage
 from langchain_core.output_parsers import (
     BaseOutputParser,
     CommaSeparatedListOutputParser,
@@ -86,6 +87,7 @@ from langchain_core.tracers import (
     RunLogPatch,
 )
 from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.stubs import AnyStr


 class FakeTracer(BaseTracer):
@@ -106,6 +108,12 @@ class FakeTracer(BaseTracer):
         self.uuids_map[uuid] = next(self.uuids_generator)
         return self.uuids_map[uuid]

+    def _replace_message_id(self, maybe_message: Any) -> Any:
+        if isinstance(maybe_message, BaseMessage):
+            maybe_message.id = AnyStr()
+
+        return maybe_message
+
     def _copy_run(self, run: Run) -> Run:
         if run.dotted_order:
             levels = run.dotted_order.split(".")
@@ -129,6 +137,16 @@ class FakeTracer(BaseTracer):
                 "child_execution_order": None,
                 "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
                 "dotted_order": new_dotted_order,
+                "inputs": {
+                    k: self._replace_message_id(v) for k, v in run.inputs.items()
+                }
+                if isinstance(run.inputs, dict)
+                else run.inputs,
+                "outputs": {
+                    k: self._replace_message_id(v) for k, v in run.outputs.items()
+                }
+                if isinstance(run.outputs, dict)
+                else run.outputs,
             }
         )

@@ -1922,7 +1940,7 @@ def test_prompt_with_chat_model(
     tracer = FakeTracer()
     assert chain.invoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
-    ) == AIMessage(content="foo")
+    ) == AIMessage(content="foo", id=AnyStr())
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
         messages=[
@@ -1947,8 +1965,8 @@ def test_prompt_with_chat_model(
         ],
         dict(callbacks=[tracer]),
     ) == [
-        AIMessage(content="foo"),
-        AIMessage(content="foo"),
+        AIMessage(content="foo", id=AnyStr()),
+        AIMessage(content="foo", id=AnyStr()),
     ]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -1988,9 +2006,9 @@ def test_prompt_with_chat_model(
     assert [
         *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
     ] == [
-        AIMessageChunk(content="f"),
-        AIMessageChunk(content="o"),
-        AIMessageChunk(content="o"),
+        AIMessageChunk(content="f", id=AnyStr()),
+        AIMessageChunk(content="o", id=AnyStr()),
+        AIMessageChunk(content="o", id=AnyStr()),
     ]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -2026,7 +2044,7 @@ async def test_prompt_with_chat_model_async(
     tracer = FakeTracer()
     assert await chain.ainvoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
-    ) == AIMessage(content="foo")
+    ) == AIMessage(content="foo", id=AnyStr())
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
         messages=[
@@ -2051,8 +2069,8 @@ async def test_prompt_with_chat_model_async(
         ],
         dict(callbacks=[tracer]),
     ) == [
-        AIMessage(content="foo"),
-        AIMessage(content="foo"),
+        AIMessage(content="foo", id=AnyStr()),
+        AIMessage(content="foo", id=AnyStr()),
     ]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -2095,9 +2113,9 @@ async def test_prompt_with_chat_model_async(
             {"question": "What is your name?"}, dict(callbacks=[tracer])
         )
     ] == [
-        AIMessageChunk(content="f"),
-        AIMessageChunk(content="o"),
-        AIMessageChunk(content="o"),
+        AIMessageChunk(content="f", id=AnyStr()),
+        AIMessageChunk(content="o", id=AnyStr()),
+        AIMessageChunk(content="o", id=AnyStr()),
     ]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -2762,7 +2780,7 @@ def test_prompt_with_chat_model_and_parser(
             HumanMessage(content="What is your name?"),
         ]
     )
-    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
+    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())

     assert tracer.runs == snapshot

@@ -2895,7 +2913,7 @@ What is your name?"""
             ),
         ]
     )
-    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
+    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
     assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 4
@@ -2941,7 +2959,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
     assert chain.invoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
     ) == {
-        "chat": AIMessage(content="i'm a chatbot"),
+        "chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
         "llm": "i'm a textbot",
     }
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -3151,7 +3169,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
     assert chain.invoke(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
     ) == {
-        "chat": AIMessage(content="i'm a chatbot"),
+        "chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
         "llm": "i'm a textbot",
         "passthrough": ChatPromptValue(
             messages=[
@@ -3360,12 +3378,13 @@ async def test_map_astream() -> None:
     assert streamed_chunks[0] in [
         {"passthrough": prompt.invoke({"question": "What is your name?"})},
         {"llm": "i"},
-        {"chat": AIMessageChunk(content="i")},
+        {"chat": AIMessageChunk(content="i", id=AnyStr())},
     ]
     assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
     assert all(len(c.keys()) == 1 for c in streamed_chunks)
     assert final_value is not None
     assert final_value.get("chat").content == "i'm a chatbot"
+    final_value["chat"].id = AnyStr()
     assert final_value.get("llm") == "i'm a textbot"
     assert final_value.get("passthrough") == prompt.invoke(
         {"question": "What is your name?"}
@@ -29,6 +29,7 @@ from langchain_core.runnables import (
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.runnables.schema import StreamEvent
 from langchain_core.tools import tool
+from tests.unit_tests.stubs import AnyStr


 def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
@@ -340,7 +341,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello")},
+            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -348,7 +349,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -356,7 +357,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!")},
+            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -364,7 +365,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"output": AIMessageChunk(content="hello world!")},
+            "data": {"output": AIMessageChunk(content="hello world!", id=AnyStr())},
             "event": "on_chat_model_end",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -399,7 +400,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello")},
+            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -407,7 +408,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -415,7 +416,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!")},
+            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -430,7 +431,9 @@ async def test_astream_events_from_model() -> None:
                 [
                     {
                         "generation_info": None,
-                        "message": AIMessage(content="hello world!"),
+                        "message": AIMessage(
+                            content="hello world!", id=AnyStr()
+                        ),
                         "text": "hello world!",
                         "type": "ChatGeneration",
                     }
@@ -447,7 +450,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessage(content="hello world!")},
+            "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
             "event": "on_chain_stream",
             "metadata": {},
             "name": "i_dont_stream",
@@ -455,7 +458,7 @@ async def test_astream_events_from_model() -> None:
             "tags": [],
         },
         {
-            "data": {"output": AIMessage(content="hello world!")},
+            "data": {"output": AIMessage(content="hello world!", id=AnyStr())},
             "event": "on_chain_end",
             "metadata": {},
             "name": "i_dont_stream",
@@ -490,7 +493,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello")},
+            "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -498,7 +501,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -506,7 +509,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!")},
+            "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b"},
             "name": "my_model",
@@ -521,7 +524,9 @@ async def test_astream_events_from_model() -> None:
                 [
                     {
                         "generation_info": None,
-                        "message": AIMessage(content="hello world!"),
+                        "message": AIMessage(
+                            content="hello world!", id=AnyStr()
+                        ),
                         "text": "hello world!",
                         "type": "ChatGeneration",
                     }
@@ -538,7 +543,7 @@ async def test_astream_events_from_model() -> None:
             "tags": ["my_model"],
         },
         {
-            "data": {"chunk": AIMessage(content="hello world!")},
+            "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
             "event": "on_chain_stream",
             "metadata": {},
             "name": "ai_dont_stream",
@@ -546,7 +551,7 @@ async def test_astream_events_from_model() -> None:
             "tags": [],
         },
         {
-            "data": {"output": AIMessage(content="hello world!")},
+            "data": {"output": AIMessage(content="hello world!", id=AnyStr())},
             "event": "on_chain_end",
             "metadata": {},
             "name": "ai_dont_stream",
@@ -563,7 +568,10 @@ async def test_event_stream_with_simple_chain() -> None:
     ).with_config({"run_name": "my_template", "tags": ["my_template"]})

     infinite_cycle = cycle(
-        [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
+        [
+            AIMessage(content="hello world!", id="ai1"),
+            AIMessage(content="goodbye world!", id="ai2"),
+        ]
     )
     # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
     model = (
@@ -640,7 +648,7 @@ async def test_event_stream_with_simple_chain() -> None:
             "tags": ["my_chain", "my_model", "seq:step:2"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello")},
+            "data": {"chunk": AIMessageChunk(content="hello", id="ai1")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "foo": "bar"},
             "name": "my_model",
@@ -648,7 +656,7 @@ async def test_event_stream_with_simple_chain() -> None:
             "tags": ["my_chain", "my_model", "seq:step:2"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="hello")},
+            "data": {"chunk": AIMessageChunk(content="hello", id="ai1")},
             "event": "on_chain_stream",
             "metadata": {"foo": "bar"},
             "name": "my_chain",
@@ -656,7 +664,7 @@ async def test_event_stream_with_simple_chain() -> None:
             "tags": ["my_chain"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id="ai1")},
            "event": "on_chat_model_stream",
             "metadata": {"a": "b", "foo": "bar"},
             "name": "my_model",
@@ -664,7 +672,7 @@ async def test_event_stream_with_simple_chain() -> None:
             "tags": ["my_chain", "my_model", "seq:step:2"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id="ai1")},
             "event": "on_chain_stream",
             "metadata": {"foo": "bar"},
             "name": "my_chain",
@@ -672,7 +680,7 @@ async def test_event_stream_with_simple_chain() -> None:
             "tags": ["my_chain"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!")},
+            "data": {"chunk": AIMessageChunk(content="world!", id="ai1")},
             "event": "on_chat_model_stream",
             "metadata": {"a": "b", "foo": "bar"},
             "name": "my_model",
@@ -680,7 +688,7 @@ async def test_event_stream_with_simple_chain() -> None:
             "tags": ["my_chain", "my_model", "seq:step:2"],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world!")},
+            "data": {"chunk": AIMessageChunk(content="world!", id="ai1")},
             "event": "on_chain_stream",
             "metadata": {"foo": "bar"},
             "name": "my_chain",
@@ -702,7 +710,9 @@ async def test_event_stream_with_simple_chain() -> None:
                 [
                     {
                         "generation_info": None,
-                        "message": AIMessageChunk(content="hello world!"),
+                        "message": AIMessageChunk(
+                            content="hello world!", id="ai1"
+                        ),
                         "text": "hello world!",
                         "type": "ChatGenerationChunk",
                     }
@@ -719,7 +729,7 @@ async def test_event_stream_with_simple_chain() -> None:
             "tags": ["my_chain", "my_model", "seq:step:2"],
         },
         {
-            "data": {"output": AIMessageChunk(content="hello world!")},
+            "data": {"output": AIMessageChunk(content="hello world!", id="ai1")},
             "event": "on_chain_end",
             "metadata": {"foo": "bar"},
             "name": "my_chain",
@@ -1332,8 +1342,8 @@ async def test_runnable_each() -> None:

 async def test_events_astream_config() -> None:
     """Test that astream events support accepting config"""
-    infinite_cycle = cycle([AIMessage(content="hello world!")])
-    good_world_on_repeat = cycle([AIMessage(content="Goodbye world")])
+    infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")])
+    good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")])
     model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
         messages=ConfigurableField(
             id="messages",
@@ -1343,7 +1353,7 @@ async def test_events_astream_config() -> None:
     )

     model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}})
-    assert model_02.invoke("hello") == AIMessage(content="Goodbye world")
+    assert model_02.invoke("hello") == AIMessage(content="Goodbye world", id="ai2")

     events = await _collect_events(model_02.astream_events("hello", version="v1"))
     assert events == [
@@ -1356,7 +1366,7 @@ async def test_events_astream_config() -> None:
             "tags": [],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="Goodbye")},
+            "data": {"chunk": AIMessageChunk(content="Goodbye", id="ai2")},
             "event": "on_chat_model_stream",
             "metadata": {},
             "name": "RunnableConfigurableFields",
@@ -1364,7 +1374,7 @@ async def test_events_astream_config() -> None:
             "tags": [],
         },
         {
-            "data": {"chunk": AIMessageChunk(content=" ")},
+            "data": {"chunk": AIMessageChunk(content=" ", id="ai2")},
             "event": "on_chat_model_stream",
             "metadata": {},
             "name": "RunnableConfigurableFields",
@@ -1372,7 +1382,7 @@ async def test_events_astream_config() -> None:
             "tags": [],
         },
         {
-            "data": {"chunk": AIMessageChunk(content="world")},
+            "data": {"chunk": AIMessageChunk(content="world", id="ai2")},
             "event": "on_chat_model_stream",
             "metadata": {},
             "name": "RunnableConfigurableFields",
@@ -1380,7 +1390,7 @@ async def test_events_astream_config() -> None:
             "tags": [],
         },
         {
-            "data": {"output": AIMessageChunk(content="Goodbye world")},
+            "data": {"output": AIMessageChunk(content="Goodbye world", id="ai2")},
             "event": "on_chat_model_end",
             "metadata": {},
             "name": "RunnableConfigurableFields",
@@ -1418,7 +1428,9 @@ async def test_runnable_with_message_history() -> None:
             store[session_id] = []
         return InMemoryHistory(messages=store[session_id])

-    infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="world")])
+    infinite_cycle = cycle(
+        [AIMessage(content="hello", id="ai3"), AIMessage(content="world", id="ai4")]
+    )

     prompt = ChatPromptTemplate.from_messages(
         [
@@ -1441,7 +1453,10 @@ async def test_runnable_with_message_history() -> None:
     ).ainvoke({"question": "hello"})

     assert store == {
-        "session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]
+        "session-123": [
+            HumanMessage(content="hello"),
+            AIMessage(content="hello", id="ai3"),
+        ]
     }

     with_message_history.with_config(
@@ -1450,8 +1465,8 @@ async def test_runnable_with_message_history() -> None:
     assert store == {
         "session-123": [
             HumanMessage(content="hello"),
-            AIMessage(content="hello"),
+            AIMessage(content="hello", id="ai3"),
             HumanMessage(content="meow"),
-            AIMessage(content="world"),
+            AIMessage(content="world", id="ai4"),
         ]
     }
libs/core/tests/unit_tests/stubs.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from typing import Any
+
+
+class AnyStr(str):
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, str)
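This stub compares equal to every str, which lets the updated tests assert full message equality while staying agnostic to the generated id:

    assert AnyStr() == "run-some-generated-id"  # equal to any string, so generated ids pass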
@@ -23,15 +23,16 @@ from langchain_core.messages import (


 def test_message_chunks() -> None:
-    assert AIMessageChunk(content="I am") + AIMessageChunk(
+    assert AIMessageChunk(content="I am", id="ai3") + AIMessageChunk(
         content=" indeed."
     ) == AIMessageChunk(
-        content="I am indeed."
+        content="I am indeed.", id="ai3"
     ), "MessageChunk + MessageChunk should be a MessageChunk"

     assert (
-        AIMessageChunk(content="I am") + HumanMessageChunk(content=" indeed.")
-        == AIMessageChunk(content="I am indeed.")
+        AIMessageChunk(content="I am", id="ai2")
+        + HumanMessageChunk(content=" indeed.", id="human1")
+        == AIMessageChunk(content="I am indeed.", id="ai2")
     ), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side"  # noqa: E501

     assert (
@@ -69,10 +70,10 @@ def test_message_chunks() -> None:


 def test_chat_message_chunks() -> None:
-    assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
+    assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
         role="User", content=" indeed."
     ) == ChatMessageChunk(
-        role="User", content="I am indeed."
+        id="ai4", role="User", content="I am indeed."
     ), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"

     with pytest.raises(ValueError):
@@ -94,10 +95,10 @@ def test_chat_message_chunks() -> None:


 def test_function_message_chunks() -> None:
-    assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
-        name="hello", content=" indeed."
-    ) == FunctionMessageChunk(
-        name="hello", content="I am indeed."
+    assert FunctionMessageChunk(
+        name="hello", content="I am", id="ai5"
+    ) + FunctionMessageChunk(name="hello", content=" indeed.") == FunctionMessageChunk(
+        id="ai5", name="hello", content="I am indeed."
     ), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"

     with pytest.raises(ValueError):
@@ -25,7 +25,7 @@ extended_tests:
	poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests

 test_watch:
-	poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket tests/unit_tests
+	poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests

 test_watch_extended:
	poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests
@@ -35,6 +35,7 @@ from langchain.prompts import ChatPromptTemplate
 from langchain.tools import tool
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
+from tests.unit_tests.stubs import AnyStr


 class FakeListLLM(LLM):
@@ -839,6 +840,7 @@ async def test_openai_agent_with_streaming() -> None:
             log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
             message_log=[
                 AIMessageChunk(
+                    id=AnyStr(),
                     content="",
                     additional_kwargs={
                         "function_call": {
@@ -852,6 +854,7 @@ async def test_openai_agent_with_streaming() -> None:
         ],
         "messages": [
             AIMessageChunk(
+                id=AnyStr(),
                 content="",
                 additional_kwargs={
                     "function_call": {
@@ -874,6 +877,7 @@ async def test_openai_agent_with_streaming() -> None:
             log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
             message_log=[
                 AIMessageChunk(
+                    id=AnyStr(),
                     content="",
                     additional_kwargs={
                         "function_call": {
@@ -1014,6 +1018,7 @@ async def test_openai_agent_tools_agent() -> None:
             log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
             message_log=[
                 AIMessageChunk(
+                    id=AnyStr(),
                     content="",
                     additional_kwargs={
                         "tool_calls": [
@@ -1040,6 +1045,7 @@ async def test_openai_agent_tools_agent() -> None:
         ],
         "messages": [
             AIMessageChunk(
+                id=AnyStr(),
                 content="",
                 additional_kwargs={
                     "tool_calls": [
@@ -1067,6 +1073,7 @@ async def test_openai_agent_tools_agent() -> None:
             log="\nInvoking: `check_time` with `{}`\n\n\n",
             message_log=[
                 AIMessageChunk(
+                    id=AnyStr(),
                     content="",
                     additional_kwargs={
                         "tool_calls": [
@@ -1093,6 +1100,7 @@ async def test_openai_agent_tools_agent() -> None:
         ],
         "messages": [
             AIMessageChunk(
+                id=AnyStr(),
                 content="",
                 additional_kwargs={
                     "tool_calls": [
@@ -1124,6 +1132,7 @@ async def test_openai_agent_tools_agent() -> None:
             log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
             message_log=[
                 AIMessageChunk(
+                    id=AnyStr(),
                     content="",
                     additional_kwargs={
                         "tool_calls": [
@@ -1166,6 +1175,7 @@ async def test_openai_agent_tools_agent() -> None:
             log="\nInvoking: `check_time` with `{}`\n\n\n",
             message_log=[
                 AIMessageChunk(
+                    id=AnyStr(),
                     content="",
                     additional_kwargs={
                         "tool_calls": [
@@ -119,7 +119,9 @@ class GenericFakeChatModel(BaseChatModel):
         content_chunks = cast(List[str], re.split(r"(\s)", content))

         for token in content_chunks:
-            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
+            chunk = ChatGenerationChunk(
+                message=AIMessageChunk(id=message.id, content=token)
+            )
             if run_manager:
                 run_manager.on_llm_new_token(token, chunk=chunk)
             yield chunk
@@ -136,6 +138,7 @@ class GenericFakeChatModel(BaseChatModel):
                 for fvalue_chunk in fvalue_chunks:
                     chunk = ChatGenerationChunk(
                         message=AIMessageChunk(
+                            id=message.id,
                             content="",
                             additional_kwargs={
                                 "function_call": {fkey: fvalue_chunk}
@@ -151,6 +154,7 @@ class GenericFakeChatModel(BaseChatModel):
             else:
                 chunk = ChatGenerationChunk(
                     message=AIMessageChunk(
+                        id=message.id,
                         content="",
                         additional_kwargs={"function_call": {fkey: fvalue}},
                     )
@@ -164,7 +168,7 @@ class GenericFakeChatModel(BaseChatModel):
         else:
             chunk = ChatGenerationChunk(
                 message=AIMessageChunk(
-                    content="", additional_kwargs={key: value}
+                    id=message.id, content="", additional_kwargs={key: value}
                 )
             )
             if run_manager:
@@ -8,6 +8,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

 from langchain.callbacks.base import AsyncCallbackHandler
 from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
+from tests.unit_tests.stubs import AnyStr


 def test_generic_fake_chat_model_invoke() -> None:
@@ -15,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = model.invoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = model.invoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())


 async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -27,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
     infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
     model = GenericFakeChatModel(messages=infinite_cycle)
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())
     response = await model.ainvoke("kitty")
-    assert response == AIMessage(content="goodbye")
+    assert response == AIMessage(content="goodbye", id=AnyStr())
     response = await model.ainvoke("meow")
-    assert response == AIMessage(content="hello")
+    assert response == AIMessage(content="hello", id=AnyStr())


 async def test_generic_fake_chat_model_stream() -> None:
@@ -44,16 +45,16 @@ async def test_generic_fake_chat_model_stream() -> None:
     model = GenericFakeChatModel(messages=infinite_cycle)
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]

     chunks = [chunk for chunk in model.stream("meow")]
     assert chunks == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]

     # Test streaming of additional kwargs.
@@ -62,11 +63,12 @@ async def test_generic_fake_chat_model_stream() -> None:
     model = GenericFakeChatModel(messages=cycle([message]))
     chunks = [chunk async for chunk in model.astream("meow")]
     assert chunks == [
-        AIMessageChunk(content="", additional_kwargs={"foo": 42}),
-        AIMessageChunk(content="", additional_kwargs={"bar": 24}),
+        AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
+        AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
     ]

     message = AIMessage(
+        id="a1",
         content="",
         additional_kwargs={
             "function_call": {
@@ -81,18 +83,22 @@ async def test_generic_fake_chat_model_stream() -> None:

     assert chunks == [
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"name": "move_file"}}
+            content="",
+            additional_kwargs={"function_call": {"name": "move_file"}},
+            id="a1",
         ),
         AIMessageChunk(
+            id="a1",
             content="",
             additional_kwargs={
                 "function_call": {"arguments": '{\n "source_path": "foo"'}
             },
         ),
         AIMessageChunk(
-            content="", additional_kwargs={"function_call": {"arguments": ","}}
+            id="a1", content="", additional_kwargs={"function_call": {"arguments": ","}}
         ),
         AIMessageChunk(
+            id="a1",
             content="",
             additional_kwargs={
                 "function_call": {"arguments": '\n "destination_path": "bar"\n}'}
@@ -108,6 +114,7 @@ async def test_generic_fake_chat_model_stream() -> None:
         accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
+        id="a1",
         content="",
         additional_kwargs={
             "function_call": {
@@ -128,9 +135,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
     ]
     final = log_patches[-1]
     assert final.state["streamed_output"] == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]


@@ -178,8 +185,8 @@ async def test_callback_handlers() -> None:
     # New model
     results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
     assert results == [
-        AIMessageChunk(content="hello"),
-        AIMessageChunk(content=" "),
-        AIMessageChunk(content="goodbye"),
+        AIMessageChunk(content="hello", id=AnyStr()),
+        AIMessageChunk(content=" ", id=AnyStr()),
+        AIMessageChunk(content="goodbye", id=AnyStr()),
     ]
     assert tokens == ["hello", " ", "goodbye"]
[One file diff omitted here: suppressed because it is too large.]
libs/langchain/tests/unit_tests/stubs.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from typing import Any
+
+
+class AnyStr(str):
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, str)