mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
core[patch]: BaseTracer
helper method for Run
lookup (#14139)
I observed the same run ID extraction logic is repeated many times in `BaseTracer`. This PR creates a helper method for DRY code.
This commit is contained in:
parent
41ee3be95f
commit
bdb6ae2ed3
@ -92,6 +92,17 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
return parent_run.child_execution_order + 1
|
return parent_run.child_execution_order + 1
|
||||||
|
|
||||||
|
def _get_run(self, run_id: UUID, run_type: str | None = None) -> Run:
|
||||||
|
try:
|
||||||
|
run = self.run_map[str(run_id)]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise TracerException(f"No indexed run ID {run_id}.") from exc
|
||||||
|
if run_type is not None and run.run_type != run_type:
|
||||||
|
raise TracerException(
|
||||||
|
f"Found {run.run_type} run at ID {run_id}, but expected {run_type} run."
|
||||||
|
)
|
||||||
|
return run
|
||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
@ -138,13 +149,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||||
if not run_id:
|
llm_run = self._get_run(run_id, run_type="llm")
|
||||||
raise TracerException("No run_id provided for on_llm_new_token callback.")
|
|
||||||
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
llm_run = self.run_map.get(run_id_)
|
|
||||||
if llm_run is None or llm_run.run_type != "llm":
|
|
||||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
|
||||||
event_kwargs: Dict[str, Any] = {"token": token}
|
event_kwargs: Dict[str, Any] = {"token": token}
|
||||||
if chunk:
|
if chunk:
|
||||||
event_kwargs["chunk"] = chunk
|
event_kwargs["chunk"] = chunk
|
||||||
@ -165,12 +170,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
if not run_id:
|
llm_run = self._get_run(run_id)
|
||||||
raise TracerException("No run_id provided for on_retry callback.")
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
llm_run = self.run_map.get(run_id_)
|
|
||||||
if llm_run is None:
|
|
||||||
raise TracerException("No Run found to be traced for on_retry")
|
|
||||||
retry_d: Dict[str, Any] = {
|
retry_d: Dict[str, Any] = {
|
||||||
"slept": retry_state.idle_for,
|
"slept": retry_state.idle_for,
|
||||||
"attempt": retry_state.attempt_number,
|
"attempt": retry_state.attempt_number,
|
||||||
@ -196,13 +196,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||||
"""End a trace for an LLM run."""
|
"""End a trace for an LLM run."""
|
||||||
if not run_id:
|
llm_run = self._get_run(run_id, run_type="llm")
|
||||||
raise TracerException("No run_id provided for on_llm_end callback.")
|
|
||||||
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
llm_run = self.run_map.get(run_id_)
|
|
||||||
if llm_run is None or llm_run.run_type != "llm":
|
|
||||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
|
||||||
llm_run.outputs = response.dict()
|
llm_run.outputs = response.dict()
|
||||||
for i, generations in enumerate(response.generations):
|
for i, generations in enumerate(response.generations):
|
||||||
for j, generation in enumerate(generations):
|
for j, generation in enumerate(generations):
|
||||||
@ -225,13 +219,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Handle an error for an LLM run."""
|
"""Handle an error for an LLM run."""
|
||||||
if not run_id:
|
llm_run = self._get_run(run_id, run_type="llm")
|
||||||
raise TracerException("No run_id provided for on_llm_error callback.")
|
|
||||||
|
|
||||||
run_id_ = str(run_id)
|
|
||||||
llm_run = self.run_map.get(run_id_)
|
|
||||||
if llm_run is None or llm_run.run_type != "llm":
|
|
||||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
|
||||||
llm_run.error = repr(error)
|
llm_run.error = repr(error)
|
||||||
llm_run.end_time = datetime.utcnow()
|
llm_run.end_time = datetime.utcnow()
|
||||||
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
||||||
@ -286,12 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""End a trace for a chain run."""
|
"""End a trace for a chain run."""
|
||||||
if not run_id:
|
chain_run = self._get_run(run_id)
|
||||||
raise TracerException("No run_id provided for on_chain_end callback.")
|
|
||||||
chain_run = self.run_map.get(str(run_id))
|
|
||||||
if chain_run is None:
|
|
||||||
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
|
||||||
|
|
||||||
chain_run.outputs = (
|
chain_run.outputs = (
|
||||||
outputs if isinstance(outputs, dict) else {"output": outputs}
|
outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||||
)
|
)
|
||||||
@ -312,12 +295,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Handle an error for a chain run."""
|
"""Handle an error for a chain run."""
|
||||||
if not run_id:
|
chain_run = self._get_run(run_id)
|
||||||
raise TracerException("No run_id provided for on_chain_error callback.")
|
|
||||||
chain_run = self.run_map.get(str(run_id))
|
|
||||||
if chain_run is None:
|
|
||||||
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
|
||||||
|
|
||||||
chain_run.error = repr(error)
|
chain_run.error = repr(error)
|
||||||
chain_run.end_time = datetime.utcnow()
|
chain_run.end_time = datetime.utcnow()
|
||||||
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
||||||
@ -366,12 +344,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
|
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||||
"""End a trace for a tool run."""
|
"""End a trace for a tool run."""
|
||||||
if not run_id:
|
tool_run = self._get_run(run_id, run_type="tool")
|
||||||
raise TracerException("No run_id provided for on_tool_end callback.")
|
|
||||||
tool_run = self.run_map.get(str(run_id))
|
|
||||||
if tool_run is None or tool_run.run_type != "tool":
|
|
||||||
raise TracerException(f"No tool Run found to be traced for {run_id}")
|
|
||||||
|
|
||||||
tool_run.outputs = {"output": output}
|
tool_run.outputs = {"output": output}
|
||||||
tool_run.end_time = datetime.utcnow()
|
tool_run.end_time = datetime.utcnow()
|
||||||
tool_run.events.append({"name": "end", "time": tool_run.end_time})
|
tool_run.events.append({"name": "end", "time": tool_run.end_time})
|
||||||
@ -387,12 +360,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Handle an error for a tool run."""
|
"""Handle an error for a tool run."""
|
||||||
if not run_id:
|
tool_run = self._get_run(run_id, run_type="tool")
|
||||||
raise TracerException("No run_id provided for on_tool_error callback.")
|
|
||||||
tool_run = self.run_map.get(str(run_id))
|
|
||||||
if tool_run is None or tool_run.run_type != "tool":
|
|
||||||
raise TracerException(f"No tool Run found to be traced for {run_id}")
|
|
||||||
|
|
||||||
tool_run.error = repr(error)
|
tool_run.error = repr(error)
|
||||||
tool_run.end_time = datetime.utcnow()
|
tool_run.end_time = datetime.utcnow()
|
||||||
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
||||||
@ -445,12 +413,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Run when Retriever errors."""
|
"""Run when Retriever errors."""
|
||||||
if not run_id:
|
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||||
raise TracerException("No run_id provided for on_retriever_error callback.")
|
|
||||||
retrieval_run = self.run_map.get(str(run_id))
|
|
||||||
if retrieval_run is None or retrieval_run.run_type != "retriever":
|
|
||||||
raise TracerException(f"No retriever Run found to be traced for {run_id}")
|
|
||||||
|
|
||||||
retrieval_run.error = repr(error)
|
retrieval_run.error = repr(error)
|
||||||
retrieval_run.end_time = datetime.utcnow()
|
retrieval_run.end_time = datetime.utcnow()
|
||||||
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
||||||
@ -462,11 +425,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Run when Retriever ends running."""
|
"""Run when Retriever ends running."""
|
||||||
if not run_id:
|
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||||
raise TracerException("No run_id provided for on_retriever_end callback.")
|
|
||||||
retrieval_run = self.run_map.get(str(run_id))
|
|
||||||
if retrieval_run is None or retrieval_run.run_type != "retriever":
|
|
||||||
raise TracerException(f"No retriever Run found to be traced for {run_id}")
|
|
||||||
retrieval_run.outputs = {"documents": documents}
|
retrieval_run.outputs = {"documents": documents}
|
||||||
retrieval_run.end_time = datetime.utcnow()
|
retrieval_run.end_time = datetime.utcnow()
|
||||||
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
|
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
|
||||||
|
Loading…
Reference in New Issue
Block a user