mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 22:28:03 +00:00
fix(core): fix tracing for new message types in case of multiple messages (#32352)
This commit is contained in:
parent
642262f6fe
commit
740d9d3e7e
@ -1425,6 +1425,32 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
list[CallbackManagerForLLMRun]: A callback manager for each
|
list[CallbackManagerForLLMRun]: A callback manager for each
|
||||||
list of messages as an LLM run.
|
list of messages as an LLM run.
|
||||||
"""
|
"""
|
||||||
|
if messages and isinstance(messages[0], MessageV1Types):
|
||||||
|
run_id_ = run_id if run_id is not None else uuid.uuid4()
|
||||||
|
handle_event(
|
||||||
|
self.handlers,
|
||||||
|
"on_chat_model_start",
|
||||||
|
"ignore_chat_model",
|
||||||
|
serialized,
|
||||||
|
messages,
|
||||||
|
run_id=run_id_,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
CallbackManagerForLLMRun(
|
||||||
|
run_id=run_id_,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
)
|
||||||
|
]
|
||||||
managers = []
|
managers = []
|
||||||
for message_list in messages:
|
for message_list in messages:
|
||||||
if run_id is not None:
|
if run_id is not None:
|
||||||
@ -1936,10 +1962,51 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
async callback managers, one for each LLM Run
|
async callback managers, one for each LLM Run
|
||||||
corresponding to each inner message list.
|
corresponding to each inner message list.
|
||||||
"""
|
"""
|
||||||
|
if messages and isinstance(messages[0], MessageV1Types):
|
||||||
|
run_id_ = run_id if run_id is not None else uuid.uuid4()
|
||||||
|
inline_tasks = []
|
||||||
|
non_inline_tasks = []
|
||||||
|
for handler in self.handlers:
|
||||||
|
task = ahandle_event(
|
||||||
|
[handler],
|
||||||
|
"on_chat_model_start",
|
||||||
|
"ignore_chat_model",
|
||||||
|
serialized,
|
||||||
|
messages,
|
||||||
|
run_id=run_id_,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if handler.run_inline:
|
||||||
|
inline_tasks.append(task)
|
||||||
|
else:
|
||||||
|
non_inline_tasks.append(task)
|
||||||
|
managers = [
|
||||||
|
AsyncCallbackManagerForLLMRun(
|
||||||
|
run_id=run_id_,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
# Run inline tasks sequentially
|
||||||
|
for task in inline_tasks:
|
||||||
|
await task
|
||||||
|
|
||||||
|
# Run non-inline tasks concurrently
|
||||||
|
if non_inline_tasks:
|
||||||
|
await asyncio.gather(*non_inline_tasks)
|
||||||
|
|
||||||
|
return managers
|
||||||
inline_tasks = []
|
inline_tasks = []
|
||||||
non_inline_tasks = []
|
non_inline_tasks = []
|
||||||
managers = []
|
managers = []
|
||||||
|
|
||||||
for message_list in messages:
|
for message_list in messages:
|
||||||
if run_id is not None:
|
if run_id is not None:
|
||||||
run_id_ = run_id
|
run_id_ = run_id
|
||||||
|
@ -12,6 +12,8 @@ from freezegun import freeze_time
|
|||||||
from langchain_core.callbacks import AsyncCallbackManager
|
from langchain_core.callbacks import AsyncCallbackManager
|
||||||
from langchain_core.exceptions import TracerException
|
from langchain_core.exceptions import TracerException
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||||
|
from langchain_core.messages.v1 import MessageV1
|
||||||
from langchain_core.outputs import LLMResult
|
from langchain_core.outputs import LLMResult
|
||||||
from langchain_core.tracers.base import AsyncBaseTracer
|
from langchain_core.tracers.base import AsyncBaseTracer
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
@ -114,6 +116,38 @@ async def test_tracer_chat_model_run() -> None:
|
|||||||
assert tracer.runs == [compare_run]
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
async def test_tracer_chat_model_run_v1() -> None:
|
||||||
|
"""Test tracer on a Chat Model run."""
|
||||||
|
tracer = FakeAsyncTracer()
|
||||||
|
manager = AsyncCallbackManager(handlers=[tracer])
|
||||||
|
messages: list[MessageV1] = [HumanMessageV1("")]
|
||||||
|
run_managers = await manager.on_chat_model_start(
|
||||||
|
serialized=SERIALIZED_CHAT, messages=messages
|
||||||
|
)
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(run_managers[0].run_id), # type: ignore[arg-type]
|
||||||
|
name="chat_model",
|
||||||
|
start_time=datetime.now(timezone.utc),
|
||||||
|
end_time=datetime.now(timezone.utc),
|
||||||
|
events=[
|
||||||
|
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||||
|
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||||
|
],
|
||||||
|
extra={},
|
||||||
|
serialized=SERIALIZED_CHAT,
|
||||||
|
inputs={"prompts": ["Human: "]},
|
||||||
|
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||||
|
error=None,
|
||||||
|
run_type="llm",
|
||||||
|
trace_id=run_managers[0].run_id,
|
||||||
|
dotted_order=f"20230101T000000000000Z{run_managers[0].run_id}",
|
||||||
|
)
|
||||||
|
for run_manager in run_managers:
|
||||||
|
await run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
async def test_tracer_llm_run_errors_no_start() -> None:
|
async def test_tracer_llm_run_errors_no_start() -> None:
|
||||||
"""Test tracer on an LLM run without a start."""
|
"""Test tracer on an LLM run without a start."""
|
||||||
|
@ -15,6 +15,8 @@ from langsmith import Client, traceable
|
|||||||
from langchain_core.callbacks import CallbackManager
|
from langchain_core.callbacks import CallbackManager
|
||||||
from langchain_core.exceptions import TracerException
|
from langchain_core.exceptions import TracerException
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||||
|
from langchain_core.messages.v1 import MessageV1
|
||||||
from langchain_core.outputs import LLMResult
|
from langchain_core.outputs import LLMResult
|
||||||
from langchain_core.runnables import chain as as_runnable
|
from langchain_core.runnables import chain as as_runnable
|
||||||
from langchain_core.tracers.base import BaseTracer
|
from langchain_core.tracers.base import BaseTracer
|
||||||
@ -120,6 +122,38 @@ def test_tracer_chat_model_run() -> None:
|
|||||||
assert tracer.runs == [compare_run]
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_tracer_chat_model_run_v1() -> None:
|
||||||
|
"""Test tracer on a Chat Model run."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
manager = CallbackManager(handlers=[tracer])
|
||||||
|
messages: list[MessageV1] = [HumanMessageV1("")]
|
||||||
|
run_managers = manager.on_chat_model_start(
|
||||||
|
serialized=SERIALIZED_CHAT, messages=messages
|
||||||
|
)
|
||||||
|
compare_run = Run(
|
||||||
|
id=str(run_managers[0].run_id), # type: ignore[arg-type]
|
||||||
|
name="chat_model",
|
||||||
|
start_time=datetime.now(timezone.utc),
|
||||||
|
end_time=datetime.now(timezone.utc),
|
||||||
|
events=[
|
||||||
|
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||||
|
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||||
|
],
|
||||||
|
extra={},
|
||||||
|
serialized=SERIALIZED_CHAT,
|
||||||
|
inputs={"prompts": ["Human: "]},
|
||||||
|
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||||
|
error=None,
|
||||||
|
run_type="llm",
|
||||||
|
trace_id=run_managers[0].run_id,
|
||||||
|
dotted_order=f"20230101T000000000000Z{run_managers[0].run_id}",
|
||||||
|
)
|
||||||
|
for run_manager in run_managers:
|
||||||
|
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||||
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
def test_tracer_llm_run_errors_no_start() -> None:
|
def test_tracer_llm_run_errors_no_start() -> None:
|
||||||
"""Test tracer on an LLM run without a start."""
|
"""Test tracer on an LLM run without a start."""
|
||||||
|
Loading…
Reference in New Issue
Block a user