mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
community[patch]: fix WandbTracer
to work with new "RunV2" API (#22673)
- **Description:** This PR updates the `WandbTracer` to work with the new RunV2 API so that wandb Traces logging works correctly for new LangChain versions. Here's an example [run](https://wandb.ai/parambharat/langchain-tracing/runs/wpm99ftq) from the existing tests - **Issue:** https://github.com/wandb/wandb/issues/7762 - **Twitter handle:** @ParamBharat _If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17._
This commit is contained in:
parent
f0f4532579
commit
2b5631a6be
@ -5,6 +5,7 @@ import json
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -14,29 +15,45 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from langchain_core.output_parsers.pydantic import PydanticBaseModel
|
||||||
from langchain_core.tracers.base import BaseTracer
|
from langchain_core.tracers.base import BaseTracer
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from wandb import Settings as WBSettings
|
from wandb import Settings as WBSettings
|
||||||
from wandb.sdk.data_types.trace_tree import Span
|
from wandb.sdk.data_types.trace_tree import Trace
|
||||||
from wandb.sdk.lib.paths import StrPath
|
from wandb.sdk.lib.paths import StrPath
|
||||||
from wandb.wandb_run import Run as WBRun
|
from wandb.wandb_run import Run as WBRun
|
||||||
|
|
||||||
|
|
||||||
PRINT_WARNINGS = True
|
PRINT_WARNINGS = True
|
||||||
|
|
||||||
|
|
||||||
def _serialize_io(run_inputs: Optional[dict]) -> dict:
|
def _serialize_io(run_io: Optional[dict]) -> dict:
|
||||||
if not run_inputs:
|
"""Utility to serialize the input and output of a run to store in wandb.
|
||||||
|
Currently, supports serializing pydantic models and protobuf messages.
|
||||||
|
|
||||||
|
:param run_io: The inputs and outputs of the run.
|
||||||
|
:return: The serialized inputs and outputs.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not run_io:
|
||||||
return {}
|
return {}
|
||||||
from google.protobuf.json_format import MessageToJson
|
from google.protobuf.json_format import MessageToJson
|
||||||
from google.protobuf.message import Message
|
from google.protobuf.message import Message
|
||||||
|
|
||||||
serialized_inputs = {}
|
serialized_inputs = {}
|
||||||
for key, value in run_inputs.items():
|
for key, value in run_io.items():
|
||||||
if isinstance(value, Message):
|
if isinstance(value, Message):
|
||||||
serialized_inputs[key] = MessageToJson(value)
|
serialized_inputs[key] = MessageToJson(value)
|
||||||
|
|
||||||
|
elif isinstance(value, PydanticBaseModel):
|
||||||
|
serialized_inputs[key] = (
|
||||||
|
value.model_dump_json()
|
||||||
|
if hasattr(value, "model_dump_json")
|
||||||
|
else value.json()
|
||||||
|
)
|
||||||
|
|
||||||
elif key == "input_documents":
|
elif key == "input_documents":
|
||||||
serialized_inputs.update(
|
serialized_inputs.update(
|
||||||
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
|
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
|
||||||
@ -46,166 +63,7 @@ def _serialize_io(run_inputs: Optional[dict]) -> dict:
|
|||||||
return serialized_inputs
|
return serialized_inputs
|
||||||
|
|
||||||
|
|
||||||
class RunProcessor:
|
def flatten_run(run: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
"""Handles the conversion of a LangChain Runs into a WBTraceTree."""
|
|
||||||
|
|
||||||
def __init__(self, wandb_module: Any, trace_module: Any):
|
|
||||||
self.wandb = wandb_module
|
|
||||||
self.trace_tree = trace_module
|
|
||||||
|
|
||||||
def process_span(self, run: Run) -> Optional["Span"]:
|
|
||||||
"""Converts a LangChain Run into a W&B Trace Span.
|
|
||||||
:param run: The LangChain Run to convert.
|
|
||||||
:return: The converted W&B Trace Span.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
span = self._convert_lc_run_to_wb_span(run)
|
|
||||||
return span
|
|
||||||
except Exception as e:
|
|
||||||
if PRINT_WARNINGS:
|
|
||||||
self.wandb.termwarn(
|
|
||||||
f"Skipping trace saving - unable to safely convert LangChain Run "
|
|
||||||
f"into W&B Trace due to: {e}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _convert_run_to_wb_span(self, run: Run) -> "Span":
|
|
||||||
"""Base utility to create a span from a run.
|
|
||||||
:param run: The run to convert.
|
|
||||||
:return: The converted Span.
|
|
||||||
"""
|
|
||||||
attributes = {**run.extra} if run.extra else {}
|
|
||||||
attributes["execution_order"] = run.execution_order # type: ignore
|
|
||||||
|
|
||||||
return self.trace_tree.Span(
|
|
||||||
span_id=str(run.id) if run.id is not None else None,
|
|
||||||
name=run.name,
|
|
||||||
start_time_ms=int(run.start_time.timestamp() * 1000),
|
|
||||||
end_time_ms=int(run.end_time.timestamp() * 1000)
|
|
||||||
if run.end_time is not None
|
|
||||||
else None,
|
|
||||||
status_code=self.trace_tree.StatusCode.SUCCESS
|
|
||||||
if run.error is None
|
|
||||||
else self.trace_tree.StatusCode.ERROR,
|
|
||||||
status_message=run.error,
|
|
||||||
attributes=attributes,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _convert_llm_run_to_wb_span(self, run: Run) -> "Span":
|
|
||||||
"""Converts a LangChain LLM Run into a W&B Trace Span.
|
|
||||||
:param run: The LangChain LLM Run to convert.
|
|
||||||
:return: The converted W&B Trace Span.
|
|
||||||
"""
|
|
||||||
base_span = self._convert_run_to_wb_span(run)
|
|
||||||
if base_span.attributes is None:
|
|
||||||
base_span.attributes = {}
|
|
||||||
base_span.attributes["llm_output"] = (run.outputs or {}).get("llm_output", {})
|
|
||||||
|
|
||||||
base_span.results = [
|
|
||||||
self.trace_tree.Result(
|
|
||||||
inputs={"prompt": prompt},
|
|
||||||
outputs={
|
|
||||||
f"gen_{g_i}": gen["text"]
|
|
||||||
for g_i, gen in enumerate(run.outputs["generations"][ndx])
|
|
||||||
}
|
|
||||||
if (
|
|
||||||
run.outputs is not None
|
|
||||||
and len(run.outputs["generations"]) > ndx
|
|
||||||
and len(run.outputs["generations"][ndx]) > 0
|
|
||||||
)
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
for ndx, prompt in enumerate(run.inputs["prompts"] or [])
|
|
||||||
]
|
|
||||||
base_span.span_kind = self.trace_tree.SpanKind.LLM
|
|
||||||
|
|
||||||
return base_span
|
|
||||||
|
|
||||||
def _convert_chain_run_to_wb_span(self, run: Run) -> "Span":
|
|
||||||
"""Converts a LangChain Chain Run into a W&B Trace Span.
|
|
||||||
:param run: The LangChain Chain Run to convert.
|
|
||||||
:return: The converted W&B Trace Span.
|
|
||||||
"""
|
|
||||||
base_span = self._convert_run_to_wb_span(run)
|
|
||||||
|
|
||||||
base_span.results = [
|
|
||||||
self.trace_tree.Result(
|
|
||||||
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
base_span.child_spans = [
|
|
||||||
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
|
|
||||||
]
|
|
||||||
base_span.span_kind = (
|
|
||||||
self.trace_tree.SpanKind.AGENT
|
|
||||||
if "agent" in run.name.lower()
|
|
||||||
else self.trace_tree.SpanKind.CHAIN
|
|
||||||
)
|
|
||||||
|
|
||||||
return base_span
|
|
||||||
|
|
||||||
def _convert_tool_run_to_wb_span(self, run: Run) -> "Span":
|
|
||||||
"""Converts a LangChain Tool Run into a W&B Trace Span.
|
|
||||||
:param run: The LangChain Tool Run to convert.
|
|
||||||
:return: The converted W&B Trace Span.
|
|
||||||
"""
|
|
||||||
base_span = self._convert_run_to_wb_span(run)
|
|
||||||
base_span.results = [
|
|
||||||
self.trace_tree.Result(
|
|
||||||
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
base_span.child_spans = [
|
|
||||||
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
|
|
||||||
]
|
|
||||||
base_span.span_kind = self.trace_tree.SpanKind.TOOL
|
|
||||||
|
|
||||||
return base_span
|
|
||||||
|
|
||||||
def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
|
|
||||||
"""Utility to convert any generic LangChain Run into a W&B Trace Span.
|
|
||||||
:param run: The LangChain Run to convert.
|
|
||||||
:return: The converted W&B Trace Span.
|
|
||||||
"""
|
|
||||||
if run.run_type == "llm":
|
|
||||||
return self._convert_llm_run_to_wb_span(run)
|
|
||||||
elif run.run_type == "chain":
|
|
||||||
return self._convert_chain_run_to_wb_span(run)
|
|
||||||
elif run.run_type == "tool":
|
|
||||||
return self._convert_tool_run_to_wb_span(run)
|
|
||||||
else:
|
|
||||||
return self._convert_run_to_wb_span(run)
|
|
||||||
|
|
||||||
def process_model(self, run: Run) -> Optional[Dict[str, Any]]:
|
|
||||||
"""Utility to process a run for wandb model_dict serialization.
|
|
||||||
:param run: The run to process.
|
|
||||||
:return: The convert model_dict to pass to WBTraceTree.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = json.loads(run.json())
|
|
||||||
processed = self.flatten_run(data)
|
|
||||||
keep_keys = (
|
|
||||||
"id",
|
|
||||||
"name",
|
|
||||||
"serialized",
|
|
||||||
"inputs",
|
|
||||||
"outputs",
|
|
||||||
"parent_run_id",
|
|
||||||
"execution_order",
|
|
||||||
)
|
|
||||||
processed = self.truncate_run_iterative(processed, keep_keys=keep_keys)
|
|
||||||
exact_keys, partial_keys = ("lc", "type"), ("api_key",)
|
|
||||||
processed = self.modify_serialized_iterative(
|
|
||||||
processed, exact_keys=exact_keys, partial_keys=partial_keys
|
|
||||||
)
|
|
||||||
output = self.build_tree(processed)
|
|
||||||
return output
|
|
||||||
except Exception as e:
|
|
||||||
if PRINT_WARNINGS:
|
|
||||||
self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
||||||
"""Utility to flatten a nest run object into a list of runs.
|
"""Utility to flatten a nest run object into a list of runs.
|
||||||
:param run: The base run to flatten.
|
:param run: The base run to flatten.
|
||||||
:return: The flattened list of runs.
|
:return: The flattened list of runs.
|
||||||
@ -229,8 +87,9 @@ class RunProcessor:
|
|||||||
|
|
||||||
return flatten([run])
|
return flatten([run])
|
||||||
|
|
||||||
|
|
||||||
def truncate_run_iterative(
|
def truncate_run_iterative(
|
||||||
self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
|
runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Utility to truncate a list of runs dictionaries to only keep the specified
|
"""Utility to truncate a list of runs dictionaries to only keep the specified
|
||||||
keys in each run.
|
keys in each run.
|
||||||
@ -253,8 +112,8 @@ class RunProcessor:
|
|||||||
|
|
||||||
return list(map(truncate_single, runs))
|
return list(map(truncate_single, runs))
|
||||||
|
|
||||||
|
|
||||||
def modify_serialized_iterative(
|
def modify_serialized_iterative(
|
||||||
self,
|
|
||||||
runs: List[Dict[str, Any]],
|
runs: List[Dict[str, Any]],
|
||||||
exact_keys: Tuple[str, ...] = (),
|
exact_keys: Tuple[str, ...] = (),
|
||||||
partial_keys: Tuple[str, ...] = (),
|
partial_keys: Tuple[str, ...] = (),
|
||||||
@ -265,7 +124,6 @@ class RunProcessor:
|
|||||||
recursively moves the dictionaries under the kwargs key to the top level.
|
recursively moves the dictionaries under the kwargs key to the top level.
|
||||||
changes the "id" field to a string "_kind" field that tells WBTraceTree how to
|
changes the "id" field to a string "_kind" field that tells WBTraceTree how to
|
||||||
visualize the run. promotes the "serialized" field to the top level.
|
visualize the run. promotes the "serialized" field to the top level.
|
||||||
|
|
||||||
:param runs: The list of runs to modify.
|
:param runs: The list of runs to modify.
|
||||||
:param exact_keys: A tuple of keys to remove from the serialized field.
|
:param exact_keys: A tuple of keys to remove from the serialized field.
|
||||||
:param partial_keys: A tuple of partial keys to remove from the serialized
|
:param partial_keys: A tuple of partial keys to remove from the serialized
|
||||||
@ -291,9 +149,7 @@ class RunProcessor:
|
|||||||
obj = [remove_exact_and_partial_keys(x) for x in obj]
|
obj = [remove_exact_and_partial_keys(x) for x in obj]
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def handle_id_and_kwargs(
|
def handle_id_and_kwargs(obj: Dict[str, Any], root: bool = False) -> Dict[str, Any]:
|
||||||
obj: Dict[str, Any], root: bool = False
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Recursively handles the id and kwargs fields of a dictionary.
|
"""Recursively handles the id and kwargs fields of a dictionary.
|
||||||
changes the id field to a string "_kind" field that tells WBTraceTree how
|
changes the id field to a string "_kind" field that tells WBTraceTree how
|
||||||
to visualize the run. recursively moves the dictionaries under the kwargs
|
to visualize the run. recursively moves the dictionaries under the kwargs
|
||||||
@ -304,10 +160,13 @@ class RunProcessor:
|
|||||||
:return: The modified dictionary.
|
:return: The modified dictionary.
|
||||||
"""
|
"""
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
|
if "data" in obj and isinstance(obj["data"], dict):
|
||||||
|
obj = obj["data"]
|
||||||
if ("id" in obj or "name" in obj) and not root:
|
if ("id" in obj or "name" in obj) and not root:
|
||||||
_kind = obj.get("id")
|
_kind = obj.get("id")
|
||||||
if not _kind:
|
if not _kind:
|
||||||
_kind = [obj.get("name")]
|
_kind = [obj.get("name")]
|
||||||
|
if isinstance(_kind, list):
|
||||||
obj["_kind"] = _kind[-1]
|
obj["_kind"] = _kind[-1]
|
||||||
obj.pop("id", None)
|
obj.pop("id", None)
|
||||||
obj.pop("name", None)
|
obj.pop("name", None)
|
||||||
@ -344,19 +203,19 @@ class RunProcessor:
|
|||||||
|
|
||||||
_kind = transformed_dict.get("_kind", None)
|
_kind = transformed_dict.get("_kind", None)
|
||||||
name = transformed_dict.pop("name", None)
|
name = transformed_dict.pop("name", None)
|
||||||
exec_ord = transformed_dict.pop("execution_order", None)
|
|
||||||
|
|
||||||
if not name:
|
if not name:
|
||||||
name = _kind
|
name = _kind
|
||||||
|
|
||||||
output_dict = {
|
output_dict = {
|
||||||
f"{exec_ord}_{name}": transformed_dict,
|
f"{name}": transformed_dict,
|
||||||
}
|
}
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
return list(map(transform_run, runs))
|
return list(map(transform_run, runs))
|
||||||
|
|
||||||
def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
||||||
|
def build_tree(runs: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""Builds a nested dictionary from a list of runs.
|
"""Builds a nested dictionary from a list of runs.
|
||||||
:param runs: The list of runs to build the tree from.
|
:param runs: The list of runs to build the tree from.
|
||||||
:return: The nested dictionary representing the langchain Run in a tree
|
:return: The nested dictionary representing the langchain Run in a tree
|
||||||
@ -425,13 +284,20 @@ class WandbTracer(BaseTracer):
|
|||||||
_run: Optional[WBRun] = None
|
_run: Optional[WBRun] = None
|
||||||
_run_args: Optional[WandbRunArgs] = None
|
_run_args: Optional[WandbRunArgs] = None
|
||||||
|
|
||||||
def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
run_args: Optional[WandbRunArgs] = None,
|
||||||
|
io_serializer: Callable = _serialize_io,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
"""Initializes the WandbTracer.
|
"""Initializes the WandbTracer.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
|
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
|
||||||
provided, `wandb.init()` will be called with no arguments. Please
|
provided, `wandb.init()` will be called with no arguments. Please
|
||||||
refer to the `wandb.init` for more details.
|
refer to the `wandb.init` for more details.
|
||||||
|
io_serializer: callable A function that serializes the input and outputs
|
||||||
|
of a run to store in wandb. Defaults to "_serialize_io"
|
||||||
|
|
||||||
To use W&B to monitor all LangChain activity, add this tracer like any other
|
To use W&B to monitor all LangChain activity, add this tracer like any other
|
||||||
LangChain callback:
|
LangChain callback:
|
||||||
@ -457,7 +323,7 @@ class WandbTracer(BaseTracer):
|
|||||||
self._trace_tree = trace_tree
|
self._trace_tree = trace_tree
|
||||||
self._run_args = run_args
|
self._run_args = run_args
|
||||||
self._ensure_run(should_print_url=(wandb.run is None))
|
self._ensure_run(should_print_url=(wandb.run is None))
|
||||||
self.run_processor = RunProcessor(self._wandb, self._trace_tree)
|
self._io_serializer = io_serializer
|
||||||
|
|
||||||
def finish(self) -> None:
|
def finish(self) -> None:
|
||||||
"""Waits for all asynchronous processes to finish and data to upload.
|
"""Waits for all asynchronous processes to finish and data to upload.
|
||||||
@ -466,23 +332,6 @@ class WandbTracer(BaseTracer):
|
|||||||
"""
|
"""
|
||||||
self._wandb.finish()
|
self._wandb.finish()
|
||||||
|
|
||||||
def _log_trace_from_run(self, run: Run) -> None:
|
|
||||||
"""Logs a LangChain Run to W*B as a W&B Trace."""
|
|
||||||
self._ensure_run()
|
|
||||||
|
|
||||||
root_span = self.run_processor.process_span(run)
|
|
||||||
model_dict = self.run_processor.process_model(run)
|
|
||||||
|
|
||||||
if root_span is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
model_trace = self._trace_tree.WBTraceTree(
|
|
||||||
root_span=root_span,
|
|
||||||
model_dict=model_dict,
|
|
||||||
)
|
|
||||||
if self._wandb.run is not None:
|
|
||||||
self._wandb.run.log({"langchain_trace": model_trace})
|
|
||||||
|
|
||||||
def _ensure_run(self, should_print_url: bool = False) -> None:
|
def _ensure_run(self, should_print_url: bool = False) -> None:
|
||||||
"""Ensures an active W&B run exists.
|
"""Ensures an active W&B run exists.
|
||||||
|
|
||||||
@ -508,6 +357,133 @@ class WandbTracer(BaseTracer):
|
|||||||
|
|
||||||
self._wandb.run._label(repo="langchain")
|
self._wandb.run._label(repo="langchain")
|
||||||
|
|
||||||
|
def process_model_dict(self, run: Run) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Utility to process a run for wandb model_dict serialization.
|
||||||
|
:param run: The run to process.
|
||||||
|
:return: The convert model_dict to pass to WBTraceTree.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(run.json())
|
||||||
|
processed = flatten_run(data)
|
||||||
|
keep_keys = (
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"serialized",
|
||||||
|
"parent_run_id",
|
||||||
|
)
|
||||||
|
processed = truncate_run_iterative(processed, keep_keys=keep_keys)
|
||||||
|
exact_keys, partial_keys = (
|
||||||
|
("lc", "type", "graph"),
|
||||||
|
(
|
||||||
|
"api_key",
|
||||||
|
"input",
|
||||||
|
"output",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
processed = modify_serialized_iterative(
|
||||||
|
processed, exact_keys=exact_keys, partial_keys=partial_keys
|
||||||
|
)
|
||||||
|
output = build_tree(processed)
|
||||||
|
return output
|
||||||
|
except Exception as e:
|
||||||
|
if PRINT_WARNINGS:
|
||||||
|
self._wandb.termerror(f"WARNING: Failed to serialize model: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _log_trace_from_run(self, run: Run) -> None:
|
||||||
|
"""Logs a LangChain Run to W*B as a W&B Trace."""
|
||||||
|
self._ensure_run()
|
||||||
|
|
||||||
|
def create_trace(
|
||||||
|
run: "Run", parent: Optional["Trace"] = None
|
||||||
|
) -> Optional["Trace"]:
|
||||||
|
"""
|
||||||
|
Create a trace for a given run and its child runs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run (Run): The run for which to create a trace.
|
||||||
|
parent (Optional[Trace]): The parent trace.
|
||||||
|
If provided, the created trace is added as a child to the parent trace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Trace]: The created trace.
|
||||||
|
If an error occurs during the creation of the trace, None is returned.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If an error occurs during the creation of the trace,
|
||||||
|
no exception is raised and a warning is printed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_metadata_dict(r: "Run") -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Extract metadata from a given run.
|
||||||
|
|
||||||
|
This function extracts metadata from a given run
|
||||||
|
and returns it as a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
r (Run): The run from which to extract metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: A dictionary containing the extracted metadata.
|
||||||
|
"""
|
||||||
|
run_dict = json.loads(r.json())
|
||||||
|
metadata_dict = run_dict.get("metadata", {})
|
||||||
|
metadata_dict["run_id"] = run_dict.get("id")
|
||||||
|
metadata_dict["parent_run_id"] = run_dict.get("parent_run_id")
|
||||||
|
metadata_dict["tags"] = run_dict.get("tags")
|
||||||
|
metadata_dict["execution_order"] = run_dict.get(
|
||||||
|
"dotted_order", ""
|
||||||
|
).count(".")
|
||||||
|
return metadata_dict
|
||||||
|
|
||||||
|
try:
|
||||||
|
if run.run_type in ["llm", "tool"]:
|
||||||
|
run_type = run.run_type
|
||||||
|
elif run.run_type == "chain":
|
||||||
|
run_type = "agent" if "agent" in run.name.lower() else "chain"
|
||||||
|
else:
|
||||||
|
run_type = None
|
||||||
|
|
||||||
|
metadata = get_metadata_dict(run)
|
||||||
|
trace_tree = self._trace_tree.Trace(
|
||||||
|
name=run.name,
|
||||||
|
kind=run_type,
|
||||||
|
status_code="error" if run.error else "success",
|
||||||
|
start_time_ms=int(run.start_time.timestamp() * 1000)
|
||||||
|
if run.start_time is not None
|
||||||
|
else None,
|
||||||
|
end_time_ms=int(run.end_time.timestamp() * 1000)
|
||||||
|
if run.end_time is not None
|
||||||
|
else None,
|
||||||
|
metadata=metadata,
|
||||||
|
inputs=self._io_serializer(run.inputs),
|
||||||
|
outputs=self._io_serializer(run.outputs),
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the run has child runs, recursively create traces for them
|
||||||
|
for child_run in run.child_runs:
|
||||||
|
create_trace(child_run, trace_tree)
|
||||||
|
|
||||||
|
if parent is None:
|
||||||
|
return trace_tree
|
||||||
|
else:
|
||||||
|
parent.add_child(trace_tree)
|
||||||
|
return parent
|
||||||
|
except Exception as e:
|
||||||
|
if PRINT_WARNINGS:
|
||||||
|
self._wandb.termwarn(
|
||||||
|
f"WARNING: Failed to serialize trace for run due to: {e}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
run_trace = create_trace(run)
|
||||||
|
model_dict = self.process_model_dict(run)
|
||||||
|
if model_dict is not None and run_trace is not None:
|
||||||
|
run_trace._model_dict = model_dict
|
||||||
|
if self._wandb.run is not None and run_trace is not None:
|
||||||
|
run_trace.log("langchain_trace")
|
||||||
|
|
||||||
def _persist_run(self, run: "Run") -> None:
|
def _persist_run(self, run: "Run") -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
self._log_trace_from_run(run)
|
self._log_trace_from_run(run)
|
||||||
|
@ -3,17 +3,12 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from langchain._api import create_importer
|
from langchain._api import create_importer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_community.callbacks.tracers.wandb import (
|
from langchain_community.callbacks.tracers.wandb import WandbRunArgs, WandbTracer
|
||||||
RunProcessor,
|
|
||||||
WandbRunArgs,
|
|
||||||
WandbTracer,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a way to dynamically look up deprecated imports.
|
# Create a way to dynamically look up deprecated imports.
|
||||||
# Used to consolidate logic for raising deprecation warnings and
|
# Used to consolidate logic for raising deprecation warnings and
|
||||||
# handling optional imports.
|
# handling optional imports.
|
||||||
DEPRECATED_LOOKUP = {
|
DEPRECATED_LOOKUP = {
|
||||||
"RunProcessor": "langchain_community.callbacks.tracers.wandb",
|
|
||||||
"WandbRunArgs": "langchain_community.callbacks.tracers.wandb",
|
"WandbRunArgs": "langchain_community.callbacks.tracers.wandb",
|
||||||
"WandbTracer": "langchain_community.callbacks.tracers.wandb",
|
"WandbTracer": "langchain_community.callbacks.tracers.wandb",
|
||||||
}
|
}
|
||||||
@ -27,7 +22,6 @@ def __getattr__(name: str) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RunProcessor",
|
|
||||||
"WandbRunArgs",
|
"WandbRunArgs",
|
||||||
"WandbTracer",
|
"WandbTracer",
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user