Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-26 13:01:55 +00:00
rfc(core): trace tools called in metadata
parent f241fd5c11
commit d73159b68e
@@ -543,9 +543,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
             raise err
 
-        await run_manager.on_llm_end(
-            LLMResult(generations=[[generation]]),
-        )
+        await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
 
     # --- Custom methods ---
 
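For context on the collapsed call above, here is a minimal sketch of an async callback handler that would receive this on_llm_end notification; the PrintGenerationsHandler class and its print-based body are illustrative assumptions, not part of this commit:

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import LLMResult


class PrintGenerationsHandler(AsyncCallbackHandler):
    """Hypothetical handler that logs the generations passed to on_llm_end."""

    async def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        # The chat model above calls on_llm_end with
        # LLMResult(generations=[[generation]]): one batch, one generation.
        for batch in response.generations:
            for generation in batch:
                print(generation.text)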
@@ -18,6 +18,7 @@ from typing import (
 
 from langchain_core.exceptions import TracerException
 from langchain_core.load import dumpd
+from langchain_core.messages import AIMessage
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -281,13 +282,18 @@ class _TracerCore(ABC):
     def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
         llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
         llm_run.outputs = response.model_dump()
+        if "ls_tools_called" not in llm_run.metadata:
+            llm_run.metadata["ls_tools_called"] = []
         for i, generations in enumerate(response.generations):
             for j, generation in enumerate(generations):
                 output_generation = llm_run.outputs["generations"][i][j]
                 if "message" in output_generation:
-                    output_generation["message"] = dumpd(
-                        cast("ChatGeneration", generation).message
-                    )
+                    msg = cast("ChatGeneration", generation).message
+                    output_generation["message"] = dumpd(msg)
+                    if i == 0 and j == 0 and isinstance(msg, AIMessage):
+                        llm_run.metadata["ls_tools_called"].extend(
+                            [tc["name"] for tc in msg.tool_calls]
+                        )
         llm_run.end_time = datetime.now(timezone.utc)
         llm_run.events.append({"name": "end", "time": llm_run.end_time})
 
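A minimal, self-contained sketch of what the extraction above records, using a hand-built AIMessage; the tool names and arguments here are made up for illustration:

from langchain_core.messages import AIMessage

# Hypothetical message carrying two tool calls.
msg = AIMessage(
    content="",
    tool_calls=[
        {"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"},
        {"name": "get_time", "args": {"tz": "UTC"}, "id": "call_2"},
    ],
)

# The tracer stores only the tool names, and only for the first generation
# of the first batch (the i == 0 and j == 0 check in the diff above).
ls_tools_called = [tc["name"] for tc in msg.tool_calls]
print(ls_tools_called)  # ['get_weather', 'get_time']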