diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 7040900aeef..fcc8429e582 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -158,6 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 tags=config.get("tags"),
                 metadata=config.get("metadata"),
                 run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
                 **kwargs,
             ).generations[0][0],
         ).message
@@ -178,6 +179,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             tags=config.get("tags"),
             metadata=config.get("metadata"),
             run_name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             **kwargs,
         )
         return cast(ChatGeneration, llm_result.generations[0][0]).message
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index f65646f445a..a0d79ca9ea4 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -1,5 +1,6 @@
 """Test base chat model."""
 
+import uuid
 from typing import Any, AsyncIterator, Iterator, List, Optional
 
 import pytest
@@ -15,7 +16,9 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
+from langchain_core.tracers.schemas import Run
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
     FakeAsyncCallbackHandler,
@@ -228,3 +231,44 @@ async def test_astream_implementation_uses_astream() -> None:
         AIMessageChunk(content="b", id=AnyStr()),
     ]
     assert len({chunk.id for chunk in chunks}) == 1
+
+
+class FakeTracer(BaseTracer):
+    def __init__(self) -> None:
+        super().__init__()
+        self.traced_run_ids: list = []
+
+    def _persist_run(self, run: Run) -> None:
+        """Persist a run."""
+
+        self.traced_run_ids.append(run.id)
+
+
+def test_pass_run_id() -> None:
+    llm = FakeListChatModel(responses=["a", "b", "c"])
+    cb = FakeTracer()
+    uid1 = uuid.uuid4()
+    llm.invoke("Dummy message", {"callbacks": [cb], "run_id": uid1})
+    assert cb.traced_run_ids == [uid1]
+    uid2 = uuid.uuid4()
+    list(llm.stream("Dummy message", {"callbacks": [cb], "run_id": uid2}))
+    assert cb.traced_run_ids == [uid1, uid2]
+    uid3 = uuid.uuid4()
+    llm.batch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
+    assert cb.traced_run_ids == [uid1, uid2, uid3]
+
+
+async def test_async_pass_run_id() -> None:
+    llm = FakeListChatModel(responses=["a", "b", "c"])
+    cb = FakeTracer()
+    uid1 = uuid.uuid4()
+    await llm.ainvoke("Dummy message", {"callbacks": [cb], "run_id": uid1})
+    assert cb.traced_run_ids == [uid1]
+    uid2 = uuid.uuid4()
+    async for _ in llm.astream("Dummy message", {"callbacks": [cb], "run_id": uid2}):
+        pass
+    assert cb.traced_run_ids == [uid1, uid2]
+
+    uid3 = uuid.uuid4()
+    await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
+    assert cb.traced_run_ids == [uid1, uid2, uid3]