mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-21 18:39:57 +00:00
core: Assign missing message ids in BaseChatModel (#19863)
- This ensures ids are stable across streamed chunks - Multiple messages in batch call get separate ids - Also fix ids being dropped when combining message chunks Thank you for contributing to LangChain! - [ ] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [ ] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17.
This commit is contained in:
@@ -25,7 +25,7 @@ extended_tests:
|
||||
poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket tests/unit_tests
|
||||
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests
|
||||
|
||||
test_watch_extended:
|
||||
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests
|
||||
|
@@ -35,6 +35,7 @@ from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.tools import tool
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
|
||||
|
||||
class FakeListLLM(LLM):
|
||||
@@ -839,6 +840,7 @@ async def test_openai_agent_with_streaming() -> None:
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
@@ -852,6 +854,7 @@ async def test_openai_agent_with_streaming() -> None:
|
||||
],
|
||||
"messages": [
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
@@ -874,6 +877,7 @@ async def test_openai_agent_with_streaming() -> None:
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
@@ -1014,6 +1018,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
@@ -1040,6 +1045,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
],
|
||||
"messages": [
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
@@ -1067,6 +1073,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
log="\nInvoking: `check_time` with `{}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
@@ -1093,6 +1100,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
],
|
||||
"messages": [
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
@@ -1124,6 +1132,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
@@ -1166,6 +1175,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
log="\nInvoking: `check_time` with `{}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
|
@@ -119,7 +119,9 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
content_chunks = cast(List[str], re.split(r"(\s)", content))
|
||||
|
||||
for token in content_chunks:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(id=message.id, content=token)
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token, chunk=chunk)
|
||||
yield chunk
|
||||
@@ -136,6 +138,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
for fvalue_chunk in fvalue_chunks:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
id=message.id,
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {fkey: fvalue_chunk}
|
||||
@@ -151,6 +154,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
else:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
id=message.id,
|
||||
content="",
|
||||
additional_kwargs={"function_call": {fkey: fvalue}},
|
||||
)
|
||||
@@ -164,7 +168,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
else:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="", additional_kwargs={key: value}
|
||||
id=message.id, content="", additional_kwargs={key: value}
|
||||
)
|
||||
)
|
||||
if run_manager:
|
||||
|
@@ -8,6 +8,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
|
||||
|
||||
def test_generic_fake_chat_model_invoke() -> None:
|
||||
@@ -15,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
response = model.invoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
@@ -27,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
response = await model.ainvoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_stream() -> None:
|
||||
@@ -44,16 +45,16 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
]
|
||||
|
||||
chunks = [chunk for chunk in model.stream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
]
|
||||
|
||||
# Test streaming of additional kwargs.
|
||||
@@ -62,11 +63,12 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=cycle([message]))
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="", additional_kwargs={"foo": 42}),
|
||||
AIMessageChunk(content="", additional_kwargs={"bar": 24}),
|
||||
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
|
||||
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
|
||||
]
|
||||
|
||||
message = AIMessage(
|
||||
id="a1",
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
@@ -81,18 +83,22 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
|
||||
assert chunks == [
|
||||
AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"name": "move_file"}},
|
||||
id="a1",
|
||||
),
|
||||
AIMessageChunk(
|
||||
id="a1",
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '{\n "source_path": "foo"'}
|
||||
},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||
id="a1", content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
id="a1",
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
|
||||
@@ -108,6 +114,7 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
accumulate_chunks += chunk
|
||||
|
||||
assert accumulate_chunks == AIMessageChunk(
|
||||
id="a1",
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
@@ -128,9 +135,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
|
||||
]
|
||||
final = log_patches[-1]
|
||||
assert final.state["streamed_output"] == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
]
|
||||
|
||||
|
||||
@@ -178,8 +185,8 @@ async def test_callback_handlers() -> None:
|
||||
# New model
|
||||
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
||||
assert results == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
AIMessageChunk(content="hello", id=AnyStr()),
|
||||
AIMessageChunk(content=" ", id=AnyStr()),
|
||||
AIMessageChunk(content="goodbye", id=AnyStr()),
|
||||
]
|
||||
assert tokens == ["hello", " ", "goodbye"]
|
||||
|
File diff suppressed because it is too large
Load Diff
6
libs/langchain/tests/unit_tests/stubs.py
Normal file
6
libs/langchain/tests/unit_tests/stubs.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AnyStr(str):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, str)
|
Reference in New Issue
Block a user