diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index c37aec08d21..db0301b2a1a 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -92,6 +92,17 @@ class BaseTracer(BaseCallbackHandler, ABC): return parent_run.child_execution_order + 1 + def _get_run(self, run_id: UUID, run_type: str | None = None) -> Run: + try: + run = self.run_map[str(run_id)] + except KeyError as exc: + raise TracerException(f"No indexed run ID {run_id}.") from exc + if run_type is not None and run.run_type != run_type: + raise TracerException( + f"Found {run.run_type} run at ID {run_id}, but expected {run_type} run." + ) + return run + def on_llm_start( self, serialized: Dict[str, Any], @@ -138,13 +149,7 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> Run: """Run on new LLM token. Only available when streaming is enabled.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_new_token callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") + llm_run = self._get_run(run_id, run_type="llm") event_kwargs: Dict[str, Any] = {"token": token} if chunk: event_kwargs["chunk"] = chunk @@ -165,12 +170,7 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id: UUID, **kwargs: Any, ) -> Run: - if not run_id: - raise TracerException("No run_id provided for on_retry callback.") - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None: - raise TracerException("No Run found to be traced for on_retry") + llm_run = self._get_run(run_id) retry_d: Dict[str, Any] = { "slept": retry_state.idle_for, "attempt": retry_state.attempt_number, @@ -196,13 +196,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for an LLM run.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_end callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") + llm_run = self._get_run(run_id, run_type="llm") llm_run.outputs = response.dict() for i, generations in enumerate(response.generations): for j, generation in enumerate(generations): @@ -225,13 +219,7 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> Run: """Handle an error for an LLM run.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_error callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") + llm_run = self._get_run(run_id, run_type="llm") llm_run.error = repr(error) llm_run.end_time = datetime.utcnow() llm_run.events.append({"name": "error", "time": llm_run.end_time}) @@ -286,12 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> Run: """End a trace for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_end callback.") - chain_run = self.run_map.get(str(run_id)) - if chain_run is None: - raise TracerException(f"No chain Run found to be traced for {run_id}") - + chain_run = self._get_run(run_id) chain_run.outputs = ( outputs if isinstance(outputs, dict) else {"output": outputs} ) @@ -312,12 +295,7 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> Run: """Handle an error for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_error callback.") - chain_run = self.run_map.get(str(run_id)) - if chain_run is None: - raise TracerException(f"No chain Run found to be traced for {run_id}") - + chain_run = self._get_run(run_id) chain_run.error = repr(error) chain_run.end_time = datetime.utcnow() chain_run.events.append({"name": "error", "time": chain_run.end_time}) @@ -366,12 +344,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_end callback.") - tool_run = self.run_map.get(str(run_id)) - if tool_run is None or tool_run.run_type != "tool": - raise TracerException(f"No tool Run found to be traced for {run_id}") - + tool_run = self._get_run(run_id, run_type="tool") tool_run.outputs = {"output": output} tool_run.end_time = datetime.utcnow() tool_run.events.append({"name": "end", "time": tool_run.end_time}) @@ -387,12 +360,7 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> Run: """Handle an error for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_error callback.") - tool_run = self.run_map.get(str(run_id)) - if tool_run is None or tool_run.run_type != "tool": - raise TracerException(f"No tool Run found to be traced for {run_id}") - + tool_run = self._get_run(run_id, run_type="tool") tool_run.error = repr(error) tool_run.end_time = datetime.utcnow() tool_run.events.append({"name": "error", "time": tool_run.end_time}) @@ -445,12 +413,7 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> Run: """Run when Retriever errors.""" - if not run_id: - raise TracerException("No run_id provided for on_retriever_error callback.") - retrieval_run = self.run_map.get(str(run_id)) - if retrieval_run is None or retrieval_run.run_type != "retriever": - raise TracerException(f"No retriever Run found to be traced for {run_id}") - + retrieval_run = self._get_run(run_id, run_type="retriever") retrieval_run.error = repr(error) retrieval_run.end_time = datetime.utcnow() retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time}) @@ -462,11 +425,7 @@ class BaseTracer(BaseCallbackHandler, ABC): self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any ) -> Run: """Run when Retriever ends running.""" - if not run_id: - raise TracerException("No run_id provided for on_retriever_end callback.") - retrieval_run = self.run_map.get(str(run_id)) - if retrieval_run is None or retrieval_run.run_type != "retriever": - raise TracerException(f"No retriever Run found to be traced for {run_id}") + retrieval_run = self._get_run(run_id, run_type="retriever") retrieval_run.outputs = {"documents": documents} retrieval_run.end_time = datetime.utcnow() retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})