mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
[core] fix: manually specifying run_id for chat models.invoke() and .ainvoke() (#20082)
This commit is contained in:
parent
ba602dc562
commit
039b7a472d
@ -158,6 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
tags=config.get("tags"),
|
tags=config.get("tags"),
|
||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
|
run_id=config.pop("run_id", None),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
).generations[0][0],
|
).generations[0][0],
|
||||||
).message
|
).message
|
||||||
@ -178,6 +179,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
tags=config.get("tags"),
|
tags=config.get("tags"),
|
||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
|
run_id=config.pop("run_id", None),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return cast(ChatGeneration, llm_result.generations[0][0]).message
|
return cast(ChatGeneration, llm_result.generations[0][0]).message
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Test base chat model."""
|
"""Test base chat model."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
from typing import Any, AsyncIterator, Iterator, List, Optional
|
from typing import Any, AsyncIterator, Iterator, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -15,7 +16,9 @@ from langchain_core.messages import (
|
|||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.outputs.llm_result import LLMResult
|
from langchain_core.outputs.llm_result import LLMResult
|
||||||
|
from langchain_core.tracers.base import BaseTracer
|
||||||
from langchain_core.tracers.context import collect_runs
|
from langchain_core.tracers.context import collect_runs
|
||||||
|
from langchain_core.tracers.schemas import Run
|
||||||
from tests.unit_tests.fake.callbacks import (
|
from tests.unit_tests.fake.callbacks import (
|
||||||
BaseFakeCallbackHandler,
|
BaseFakeCallbackHandler,
|
||||||
FakeAsyncCallbackHandler,
|
FakeAsyncCallbackHandler,
|
||||||
@ -228,3 +231,44 @@ async def test_astream_implementation_uses_astream() -> None:
|
|||||||
AIMessageChunk(content="b", id=AnyStr()),
|
AIMessageChunk(content="b", id=AnyStr()),
|
||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTracer(BaseTracer):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.traced_run_ids: list = []
|
||||||
|
|
||||||
|
def _persist_run(self, run: Run) -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
|
||||||
|
self.traced_run_ids.append(run.id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pass_run_id() -> None:
|
||||||
|
llm = FakeListChatModel(responses=["a", "b", "c"])
|
||||||
|
cb = FakeTracer()
|
||||||
|
uid1 = uuid.uuid4()
|
||||||
|
llm.invoke("Dummy message", {"callbacks": [cb], "run_id": uid1})
|
||||||
|
assert cb.traced_run_ids == [uid1]
|
||||||
|
uid2 = uuid.uuid4()
|
||||||
|
list(llm.stream("Dummy message", {"callbacks": [cb], "run_id": uid2}))
|
||||||
|
assert cb.traced_run_ids == [uid1, uid2]
|
||||||
|
uid3 = uuid.uuid4()
|
||||||
|
llm.batch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
|
||||||
|
assert cb.traced_run_ids == [uid1, uid2, uid3]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_pass_run_id() -> None:
|
||||||
|
llm = FakeListChatModel(responses=["a", "b", "c"])
|
||||||
|
cb = FakeTracer()
|
||||||
|
uid1 = uuid.uuid4()
|
||||||
|
await llm.ainvoke("Dummy message", {"callbacks": [cb], "run_id": uid1})
|
||||||
|
assert cb.traced_run_ids == [uid1]
|
||||||
|
uid2 = uuid.uuid4()
|
||||||
|
async for _ in llm.astream("Dummy message", {"callbacks": [cb], "run_id": uid2}):
|
||||||
|
pass
|
||||||
|
assert cb.traced_run_ids == [uid1, uid2]
|
||||||
|
|
||||||
|
uid3 = uuid.uuid4()
|
||||||
|
await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
|
||||||
|
assert cb.traced_run_ids == [uid1, uid2, uid3]
|
||||||
|
Loading…
Reference in New Issue
Block a user