rfc(core): trace tools called in metadata

This commit is contained in:
Bagatur 2025-04-02 17:31:10 -07:00
parent f241fd5c11
commit d73159b68e
2 changed files with 10 additions and 6 deletions

View File

@@ -543,9 +543,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err
-        await run_manager.on_llm_end(
-            LLMResult(generations=[[generation]]),
-        )
+        await run_manager.on_llm_end(LLMResult(generations=[[generation]]))

     # --- Custom methods ---

View File

View File

@@ -18,6 +18,7 @@ from typing import (
 from langchain_core.exceptions import TracerException
 from langchain_core.load import dumpd
+from langchain_core.messages import AIMessage
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -281,13 +282,18 @@ class _TracerCore(ABC):
     def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
         llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
         llm_run.outputs = response.model_dump()
+        if "ls_tools_called" not in llm_run.metadata:
+            llm_run.metadata["ls_tools_called"] = []
         for i, generations in enumerate(response.generations):
             for j, generation in enumerate(generations):
                 output_generation = llm_run.outputs["generations"][i][j]
                 if "message" in output_generation:
-                    output_generation["message"] = dumpd(
-                        cast("ChatGeneration", generation).message
-                    )
+                    msg = cast("ChatGeneration", generation).message
+                    output_generation["message"] = dumpd(msg)
+                    if i == 0 and j == 0 and isinstance(msg, AIMessage):
+                        llm_run.metadata["ls_tools_called"].extend(
+                            [tc["name"] for tc in msg.tool_calls]
+                        )
         llm_run.end_time = datetime.now(timezone.utc)
         llm_run.events.append({"name": "end", "time": llm_run.end_time})