fix(core): fix tracing for new message types in case of multiple messages (#32352)

This commit is contained in:
ccurme 2025-07-31 11:47:23 -03:00 committed by GitHub
parent 642262f6fe
commit 740d9d3e7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 136 additions and 1 deletions

View File

@@ -1425,6 +1425,32 @@ class CallbackManager(BaseCallbackManager):
            list[CallbackManagerForLLMRun]: A callback manager for each
            list of messages as an LLM run.
        """
if messages and isinstance(messages[0], MessageV1Types):
run_id_ = run_id if run_id is not None else uuid.uuid4()
handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
return [
CallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
]
        managers = []
        for message_list in messages:
            if run_id is not None:
@@ -1936,10 +1962,51 @@ class AsyncCallbackManager(BaseCallbackManager):
            async callback managers, one for each LLM Run
            corresponding to each inner message list.
        """
if messages and isinstance(messages[0], MessageV1Types):
run_id_ = run_id if run_id is not None else uuid.uuid4()
inline_tasks = []
non_inline_tasks = []
for handler in self.handlers:
task = ahandle_event(
[handler],
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
if handler.run_inline:
inline_tasks.append(task)
else:
non_inline_tasks.append(task)
managers = [
AsyncCallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
]
# Run inline tasks sequentially
for task in inline_tasks:
await task
# Run non-inline tasks concurrently
if non_inline_tasks:
await asyncio.gather(*non_inline_tasks)
return managers
        inline_tasks = []
        non_inline_tasks = []
        managers = []
        for message_list in messages:
            if run_id is not None:
                run_id_ = run_id

View File

@@ -12,6 +12,8 @@ from freezegun import freeze_time
from langchain_core.callbacks import AsyncCallbackManager
from langchain_core.exceptions import TracerException
from langchain_core.messages import HumanMessage
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import LLMResult
from langchain_core.tracers.base import AsyncBaseTracer
from langchain_core.tracers.schemas import Run
@@ -114,6 +116,38 @@ async def test_tracer_chat_model_run() -> None:
    assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
async def test_tracer_chat_model_run_v1() -> None:
    """Test tracer on a Chat Model run."""
    fake_tracer = FakeAsyncTracer()
    callback_manager = AsyncCallbackManager(handlers=[fake_tracer])
    v1_messages: list[MessageV1] = [HumanMessageV1("")]
    managers = await callback_manager.on_chat_model_start(
        serialized=SERIALIZED_CHAT, messages=v1_messages
    )
    # Time is frozen, so one timestamp serves for start, end, and both events.
    frozen_now = datetime.now(timezone.utc)
    run_id = managers[0].run_id
    expected_run = Run(
        id=str(run_id),  # type: ignore[arg-type]
        name="chat_model",
        start_time=frozen_now,
        end_time=frozen_now,
        events=[
            {"name": "start", "time": frozen_now},
            {"name": "end", "time": frozen_now},
        ],
        extra={},
        serialized=SERIALIZED_CHAT,
        inputs={"prompts": ["Human: "]},
        outputs=LLMResult(generations=[[]]),  # type: ignore[arg-type]
        error=None,
        run_type="llm",
        trace_id=run_id,
        dotted_order=f"20230101T000000000000Z{run_id}",
    )
    for llm_run_manager in managers:
        await llm_run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
async def test_tracer_llm_run_errors_no_start() -> None:
    """Test tracer on an LLM run without a start."""

View File

@@ -15,6 +15,8 @@ from langsmith import Client, traceable
from langchain_core.callbacks import CallbackManager
from langchain_core.exceptions import TracerException
from langchain_core.messages import HumanMessage
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import LLMResult
from langchain_core.runnables import chain as as_runnable
from langchain_core.tracers.base import BaseTracer
@@ -120,6 +122,38 @@ def test_tracer_chat_model_run() -> None:
    assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run_v1() -> None:
    """Test tracer on a Chat Model run."""
    fake_tracer = FakeTracer()
    callback_manager = CallbackManager(handlers=[fake_tracer])
    v1_messages: list[MessageV1] = [HumanMessageV1("")]
    managers = callback_manager.on_chat_model_start(
        serialized=SERIALIZED_CHAT, messages=v1_messages
    )
    # Time is frozen, so one timestamp serves for start, end, and both events.
    frozen_now = datetime.now(timezone.utc)
    run_id = managers[0].run_id
    expected_run = Run(
        id=str(run_id),  # type: ignore[arg-type]
        name="chat_model",
        start_time=frozen_now,
        end_time=frozen_now,
        events=[
            {"name": "start", "time": frozen_now},
            {"name": "end", "time": frozen_now},
        ],
        extra={},
        serialized=SERIALIZED_CHAT,
        inputs={"prompts": ["Human: "]},
        outputs=LLMResult(generations=[[]]),  # type: ignore[arg-type]
        error=None,
        run_type="llm",
        trace_id=run_id,
        dotted_order=f"20230101T000000000000Z{run_id}",
    )
    for llm_run_manager in managers:
        llm_run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Test tracer on an LLM run without a start."""