[core] fix: honor a manually specified run_id in chat model .invoke() and .ainvoke() (#20082)

Author: William FH
Date: 2024-04-06 16:57:32 -07:00
Committed by: GitHub
Parent: ba602dc562
Commit: 039b7a472d
2 changed files with 46 additions and 0 deletions


@@ -158,6 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 tags=config.get("tags"),
                 metadata=config.get("metadata"),
                 run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
                 **kwargs,
             ).generations[0][0],
         ).message
@@ -178,6 +179,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             tags=config.get("tags"),
             metadata=config.get("metadata"),
             run_name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             **kwargs,
         )
         return cast(ChatGeneration, llm_result.generations[0][0]).message
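
Usage note (not part of the diff): with this change a caller can pre-generate a run ID and pass it through the config dict, so the run can be referenced, e.g. for attaching feedback to a trace, before the model even responds. A minimal sketch, using the FakeListChatModel that ships with langchain_core:

    # Sketch only: assumes a langchain_core build that includes this fix.
    import uuid

    from langchain_core.language_models import FakeListChatModel

    llm = FakeListChatModel(responses=["hello"])
    my_run_id = uuid.uuid4()  # chosen up front, before the call
    message = llm.invoke("Hi there", {"run_id": my_run_id})
    # Any tracer attached via the "callbacks" config key now records
    # a run whose id is exactly my_run_id.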


@@ -1,5 +1,6 @@
 """Test base chat model."""
+import uuid
 from typing import Any, AsyncIterator, Iterator, List, Optional

 import pytest
@@ -15,7 +16,9 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
+from langchain_core.tracers.schemas import Run
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
     FakeAsyncCallbackHandler,
@@ -228,3 +231,44 @@ async def test_astream_implementation_uses_astream() -> None:
         AIMessageChunk(content="b", id=AnyStr()),
     ]
     assert len({chunk.id for chunk in chunks}) == 1
+
+
+class FakeTracer(BaseTracer):
+    def __init__(self) -> None:
+        super().__init__()
+        self.traced_run_ids: list = []
+
+    def _persist_run(self, run: Run) -> None:
+        """Persist a run."""
+        self.traced_run_ids.append(run.id)
+
+
+def test_pass_run_id() -> None:
+    llm = FakeListChatModel(responses=["a", "b", "c"])
+    cb = FakeTracer()
+    uid1 = uuid.uuid4()
+    llm.invoke("Dummy message", {"callbacks": [cb], "run_id": uid1})
+    assert cb.traced_run_ids == [uid1]
+    uid2 = uuid.uuid4()
+    list(llm.stream("Dummy message", {"callbacks": [cb], "run_id": uid2}))
+    assert cb.traced_run_ids == [uid1, uid2]
+    uid3 = uuid.uuid4()
+    llm.batch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
+    assert cb.traced_run_ids == [uid1, uid2, uid3]
+
+
+async def test_async_pass_run_id() -> None:
+    llm = FakeListChatModel(responses=["a", "b", "c"])
+    cb = FakeTracer()
+    uid1 = uuid.uuid4()
+    await llm.ainvoke("Dummy message", {"callbacks": [cb], "run_id": uid1})
+    assert cb.traced_run_ids == [uid1]
+    uid2 = uuid.uuid4()
+    async for _ in llm.astream("Dummy message", {"callbacks": [cb], "run_id": uid2}):
+        pass
+    assert cb.traced_run_ids == [uid1, uid2]
+    uid3 = uuid.uuid4()
+    await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
+    assert cb.traced_run_ids == [uid1, uid2, uid3]
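
A related sketch (also not part of the commit): the collect_runs helper imported above gives the same check without defining a custom tracer, since it yields a handler whose traced_runs list holds the finished runs:

    # Sketch only: collect_runs registers a run-collecting tracer for the
    # duration of the "with" block.
    import uuid

    from langchain_core.language_models import FakeListChatModel
    from langchain_core.tracers.context import collect_runs

    llm = FakeListChatModel(responses=["a"])
    uid = uuid.uuid4()
    with collect_runs() as cb:
        llm.invoke("Dummy message", {"run_id": uid})
        assert cb.traced_runs[0].id == uid  # the run carries the chosen ID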