mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 21:35:08 +00:00
fix(core): fix tracing for new message types in case of multiple messages (#32352)
This commit is contained in:
parent
642262f6fe
commit
740d9d3e7e
@ -1425,6 +1425,32 @@ class CallbackManager(BaseCallbackManager):
|
||||
list[CallbackManagerForLLMRun]: A callback manager for each
|
||||
list of messages as an LLM run.
|
||||
"""
|
||||
if messages and isinstance(messages[0], MessageV1Types):
|
||||
run_id_ = run_id if run_id is not None else uuid.uuid4()
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
return [
|
||||
CallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
]
|
||||
managers = []
|
||||
for message_list in messages:
|
||||
if run_id is not None:
|
||||
@ -1936,10 +1962,51 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
async callback managers, one for each LLM Run
|
||||
corresponding to each inner message list.
|
||||
"""
|
||||
if messages and isinstance(messages[0], MessageV1Types):
|
||||
run_id_ = run_id if run_id is not None else uuid.uuid4()
|
||||
inline_tasks = []
|
||||
non_inline_tasks = []
|
||||
for handler in self.handlers:
|
||||
task = ahandle_event(
|
||||
[handler],
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
if handler.run_inline:
|
||||
inline_tasks.append(task)
|
||||
else:
|
||||
non_inline_tasks.append(task)
|
||||
managers = [
|
||||
AsyncCallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
]
|
||||
# Run inline tasks sequentially
|
||||
for task in inline_tasks:
|
||||
await task
|
||||
|
||||
# Run non-inline tasks concurrently
|
||||
if non_inline_tasks:
|
||||
await asyncio.gather(*non_inline_tasks)
|
||||
|
||||
return managers
|
||||
inline_tasks = []
|
||||
non_inline_tasks = []
|
||||
managers = []
|
||||
|
||||
for message_list in messages:
|
||||
if run_id is not None:
|
||||
run_id_ = run_id
|
||||
|
@ -12,6 +12,8 @@ from freezegun import freeze_time
|
||||
from langchain_core.callbacks import AsyncCallbackManager
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
from langchain_core.messages.v1 import MessageV1
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.tracers.base import AsyncBaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
@ -114,6 +116,38 @@ async def test_tracer_chat_model_run() -> None:
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_chat_model_run_v1() -> None:
|
||||
"""Test tracer on a Chat Model run."""
|
||||
tracer = FakeAsyncTracer()
|
||||
manager = AsyncCallbackManager(handlers=[tracer])
|
||||
messages: list[MessageV1] = [HumanMessageV1("")]
|
||||
run_managers = await manager.on_chat_model_start(
|
||||
serialized=SERIALIZED_CHAT, messages=messages
|
||||
)
|
||||
compare_run = Run(
|
||||
id=str(run_managers[0].run_id), # type: ignore[arg-type]
|
||||
name="chat_model",
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED_CHAT,
|
||||
inputs={"prompts": ["Human: "]},
|
||||
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||
error=None,
|
||||
run_type="llm",
|
||||
trace_id=run_managers[0].run_id,
|
||||
dotted_order=f"20230101T000000000000Z{run_managers[0].run_id}",
|
||||
)
|
||||
for run_manager in run_managers:
|
||||
await run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
|
@ -15,6 +15,8 @@ from langsmith import Client, traceable
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
from langchain_core.messages.v1 import MessageV1
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.runnables import chain as as_runnable
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
@ -120,6 +122,38 @@ def test_tracer_chat_model_run() -> None:
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chat_model_run_v1() -> None:
|
||||
"""Test tracer on a Chat Model run."""
|
||||
tracer = FakeTracer()
|
||||
manager = CallbackManager(handlers=[tracer])
|
||||
messages: list[MessageV1] = [HumanMessageV1("")]
|
||||
run_managers = manager.on_chat_model_start(
|
||||
serialized=SERIALIZED_CHAT, messages=messages
|
||||
)
|
||||
compare_run = Run(
|
||||
id=str(run_managers[0].run_id), # type: ignore[arg-type]
|
||||
name="chat_model",
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED_CHAT,
|
||||
inputs={"prompts": ["Human: "]},
|
||||
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||
error=None,
|
||||
run_type="llm",
|
||||
trace_id=run_managers[0].run_id,
|
||||
dotted_order=f"20230101T000000000000Z{run_managers[0].run_id}",
|
||||
)
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
|
Loading…
Reference in New Issue
Block a user