core: Assign missing message ids in BaseChatModel (#19863)

- This ensures ids are stable across streamed chunks
- Multiple messages in batch call get separate ids
- Also fix ids being dropped when combining message chunks

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced and you'd like a
mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably a unit test that does not rely
on network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
This commit is contained in:
Nuno Campos
2024-04-01 18:18:36 -07:00
committed by GitHub
parent e830a4e731
commit 2ae6dcdf01
24 changed files with 693 additions and 49244 deletions

View File

@@ -25,7 +25,7 @@ extended_tests:
poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
test_watch:
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket tests/unit_tests
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests
test_watch_extended:
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests

View File

@@ -35,6 +35,7 @@ from langchain.prompts import ChatPromptTemplate
from langchain.tools import tool
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
from tests.unit_tests.stubs import AnyStr
class FakeListLLM(LLM):
@@ -839,6 +840,7 @@ async def test_openai_agent_with_streaming() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"function_call": {
@@ -852,6 +854,7 @@ async def test_openai_agent_with_streaming() -> None:
],
"messages": [
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"function_call": {
@@ -874,6 +877,7 @@ async def test_openai_agent_with_streaming() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"function_call": {
@@ -1014,6 +1018,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@@ -1040,6 +1045,7 @@ async def test_openai_agent_tools_agent() -> None:
],
"messages": [
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@@ -1067,6 +1073,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `check_time` with `{}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@@ -1093,6 +1100,7 @@ async def test_openai_agent_tools_agent() -> None:
],
"messages": [
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@@ -1124,6 +1132,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@@ -1166,6 +1175,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `check_time` with `{}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [

View File

@@ -119,7 +119,9 @@ class GenericFakeChatModel(BaseChatModel):
content_chunks = cast(List[str], re.split(r"(\s)", content))
for token in content_chunks:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
chunk = ChatGenerationChunk(
message=AIMessageChunk(id=message.id, content=token)
)
if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
@@ -136,6 +138,7 @@ class GenericFakeChatModel(BaseChatModel):
for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={
"function_call": {fkey: fvalue_chunk}
@@ -151,6 +154,7 @@ class GenericFakeChatModel(BaseChatModel):
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={"function_call": {fkey: fvalue}},
)
@@ -164,7 +168,7 @@ class GenericFakeChatModel(BaseChatModel):
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs={key: value}
id=message.id, content="", additional_kwargs={key: value}
)
)
if run_manager:

View File

@@ -8,6 +8,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain.callbacks.base import AsyncCallbackHandler
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
from tests.unit_tests.stubs import AnyStr
def test_generic_fake_chat_model_invoke() -> None:
@@ -15,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -27,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_stream() -> None:
@@ -44,16 +45,16 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
# Test streaming of additional kwargs.
@@ -62,11 +63,12 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}),
AIMessageChunk(content="", additional_kwargs={"bar": 24}),
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
]
message = AIMessage(
id="a1",
content="",
additional_kwargs={
"function_call": {
@@ -81,18 +83,22 @@ async def test_generic_fake_chat_model_stream() -> None:
assert chunks == [
AIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id="a1",
),
AIMessageChunk(
id="a1",
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'}
},
),
AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
id="a1", content="", additional_kwargs={"function_call": {"arguments": ","}}
),
AIMessageChunk(
id="a1",
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
@@ -108,6 +114,7 @@ async def test_generic_fake_chat_model_stream() -> None:
accumulate_chunks += chunk
assert accumulate_chunks == AIMessageChunk(
id="a1",
content="",
additional_kwargs={
"function_call": {
@@ -128,9 +135,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
@@ -178,8 +185,8 @@ async def test_callback_handlers() -> None:
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert tokens == ["hello", " ", "goodbye"]

View File

@@ -0,0 +1,6 @@
from typing import Any
class AnyStr(str):
    """A str that compares equal to any other string.

    Used in test assertions where a message ``id`` is auto-generated (e.g. a
    random UUID) and only its presence/type matters, not its exact value:
    ``AIMessage(content="hello", id=AnyStr())`` matches any string id.
    """

    def __eq__(self, other: Any) -> bool:
        """Return True iff ``other`` is any string (value is ignored)."""
        return isinstance(other, str)

    # Defining __eq__ implicitly sets __hash__ to None; restore str's hash so
    # AnyStr instances remain usable in sets and as dict keys.
    __hash__ = str.__hash__