mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 19:11:33 +00:00
feat: add model architecture back to wandb tracer (#6806)
# Description This PR adds model architecture to the `WandbTracer` from the Serialized Run kwargs. This allows visualization of the calling parameters of an Agent, LLM and Tool in Weights & Biases. 1. Safely serialize the run objects to WBTraceTree model_dict 2. Refactors the run processing logic to be more organized. - Twitter handle: @parambharat --------- Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
adc96d60b6
commit
be29a6287d
@ -1,6 +1,7 @@
|
|||||||
"""A Tracer Implementation that records activity to Weights & Biases."""
|
"""A Tracer Implementation that records activity to Weights & Biases."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -8,6 +9,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
Tuple,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -17,7 +19,7 @@ from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from wandb import Settings as WBSettings
|
from wandb import Settings as WBSettings
|
||||||
from wandb.sdk.data_types import trace_tree
|
from wandb.sdk.data_types.trace_tree import Span
|
||||||
from wandb.sdk.lib.paths import StrPath
|
from wandb.sdk.lib.paths import StrPath
|
||||||
from wandb.wandb_run import Run as WBRun
|
from wandb.wandb_run import Run as WBRun
|
||||||
|
|
||||||
@ -25,115 +27,350 @@ if TYPE_CHECKING:
|
|||||||
PRINT_WARNINGS = True
|
PRINT_WARNINGS = True
|
||||||
|
|
||||||
|
|
||||||
def _convert_lc_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
def _serialize_inputs(run_inputs: dict) -> dict:
|
||||||
if run.run_type == RunTypeEnum.llm:
|
|
||||||
return _convert_llm_run_to_wb_span(trace_tree, run)
|
|
||||||
elif run.run_type == RunTypeEnum.chain:
|
|
||||||
return _convert_chain_run_to_wb_span(trace_tree, run)
|
|
||||||
elif run.run_type == RunTypeEnum.tool:
|
|
||||||
return _convert_tool_run_to_wb_span(trace_tree, run)
|
|
||||||
else:
|
|
||||||
return _convert_run_to_wb_span(trace_tree, run)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_llm_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
|
||||||
base_span = _convert_run_to_wb_span(trace_tree, run)
|
|
||||||
|
|
||||||
base_span.results = [
|
|
||||||
trace_tree.Result(
|
|
||||||
inputs={"prompt": prompt},
|
|
||||||
outputs={
|
|
||||||
f"gen_{g_i}": gen["text"]
|
|
||||||
for g_i, gen in enumerate(run.outputs["generations"][ndx])
|
|
||||||
}
|
|
||||||
if (
|
|
||||||
run.outputs is not None
|
|
||||||
and len(run.outputs["generations"]) > ndx
|
|
||||||
and len(run.outputs["generations"][ndx]) > 0
|
|
||||||
)
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
for ndx, prompt in enumerate(run.inputs["prompts"] or [])
|
|
||||||
]
|
|
||||||
base_span.span_kind = trace_tree.SpanKind.LLM
|
|
||||||
|
|
||||||
return base_span
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_inputs(run_inputs: dict) -> Union[dict, list]:
|
|
||||||
if "input_documents" in run_inputs:
|
if "input_documents" in run_inputs:
|
||||||
docs = run_inputs["input_documents"]
|
docs = run_inputs["input_documents"]
|
||||||
return [doc.json() for doc in docs]
|
return {f"input_document_{i}": doc.json() for i, doc in enumerate(docs)}
|
||||||
else:
|
else:
|
||||||
return run_inputs
|
return run_inputs
|
||||||
|
|
||||||
|
|
||||||
def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
class RunProcessor:
|
||||||
base_span = _convert_run_to_wb_span(trace_tree, run)
|
"""Handles the conversion of a LangChain Runs into a WBTraceTree."""
|
||||||
|
|
||||||
base_span.results = [
|
def __init__(self, wandb_module: Any, trace_module: Any):
|
||||||
trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs)
|
self.wandb = wandb_module
|
||||||
]
|
self.trace_tree = trace_module
|
||||||
base_span.child_spans = [
|
|
||||||
_convert_lc_run_to_wb_span(trace_tree, child_run)
|
|
||||||
for child_run in run.child_runs
|
|
||||||
]
|
|
||||||
base_span.span_kind = (
|
|
||||||
trace_tree.SpanKind.AGENT
|
|
||||||
if "agent" in run.serialized.get("name", "").lower()
|
|
||||||
else trace_tree.SpanKind.CHAIN
|
|
||||||
)
|
|
||||||
|
|
||||||
return base_span
|
def process_span(self, run: Run) -> Optional["Span"]:
|
||||||
|
"""Converts a LangChain Run into a W&B Trace Span.
|
||||||
|
:param run: The LangChain Run to convert.
|
||||||
|
:return: The converted W&B Trace Span.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
span = self._convert_lc_run_to_wb_span(run)
|
||||||
|
return span
|
||||||
|
except Exception as e:
|
||||||
|
if PRINT_WARNINGS:
|
||||||
|
self.wandb.termwarn(
|
||||||
|
f"Skipping trace saving - unable to safely convert LangChain Run "
|
||||||
|
f"into W&B Trace due to: {e}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _convert_run_to_wb_span(self, run: Run) -> "Span":
|
||||||
|
"""Base utility to create a span from a run.
|
||||||
|
:param run: The run to convert.
|
||||||
|
:return: The converted Span.
|
||||||
|
"""
|
||||||
|
attributes = {**run.extra} if run.extra else {}
|
||||||
|
attributes["execution_order"] = run.execution_order
|
||||||
|
|
||||||
def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
return self.trace_tree.Span(
|
||||||
base_span = _convert_run_to_wb_span(trace_tree, run)
|
span_id=str(run.id) if run.id is not None else None,
|
||||||
base_span.results = [
|
name=run.name,
|
||||||
trace_tree.Result(inputs=_serialize_inputs(run.inputs), outputs=run.outputs)
|
start_time_ms=int(run.start_time.timestamp() * 1000),
|
||||||
]
|
end_time_ms=int(run.end_time.timestamp() * 1000),
|
||||||
base_span.child_spans = [
|
status_code=self.trace_tree.StatusCode.SUCCESS
|
||||||
_convert_lc_run_to_wb_span(trace_tree, child_run)
|
if run.error is None
|
||||||
for child_run in run.child_runs
|
else self.trace_tree.StatusCode.ERROR,
|
||||||
]
|
status_message=run.error,
|
||||||
base_span.span_kind = trace_tree.SpanKind.TOOL
|
attributes=attributes,
|
||||||
|
)
|
||||||
|
|
||||||
return base_span
|
def _convert_llm_run_to_wb_span(self, run: Run) -> "Span":
|
||||||
|
"""Converts a LangChain LLM Run into a W&B Trace Span.
|
||||||
|
:param run: The LangChain LLM Run to convert.
|
||||||
|
:return: The converted W&B Trace Span.
|
||||||
|
"""
|
||||||
|
base_span = self._convert_run_to_wb_span(run)
|
||||||
|
if base_span.attributes is None:
|
||||||
|
base_span.attributes = {}
|
||||||
|
base_span.attributes["llm_output"] = run.outputs.get("llm_output", {})
|
||||||
|
|
||||||
|
base_span.results = [
|
||||||
|
self.trace_tree.Result(
|
||||||
|
inputs={"prompt": prompt},
|
||||||
|
outputs={
|
||||||
|
f"gen_{g_i}": gen["text"]
|
||||||
|
for g_i, gen in enumerate(run.outputs["generations"][ndx])
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
run.outputs is not None
|
||||||
|
and len(run.outputs["generations"]) > ndx
|
||||||
|
and len(run.outputs["generations"][ndx]) > 0
|
||||||
|
)
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
for ndx, prompt in enumerate(run.inputs["prompts"] or [])
|
||||||
|
]
|
||||||
|
base_span.span_kind = self.trace_tree.SpanKind.LLM
|
||||||
|
|
||||||
def _convert_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
return base_span
|
||||||
attributes = {**run.extra} if run.extra else {}
|
|
||||||
attributes["execution_order"] = run.execution_order
|
|
||||||
|
|
||||||
return trace_tree.Span(
|
def _convert_chain_run_to_wb_span(self, run: Run) -> "Span":
|
||||||
span_id=str(run.id) if run.id is not None else None,
|
"""Converts a LangChain Chain Run into a W&B Trace Span.
|
||||||
name=run.serialized.get("name"),
|
:param run: The LangChain Chain Run to convert.
|
||||||
start_time_ms=int(run.start_time.timestamp() * 1000),
|
:return: The converted W&B Trace Span.
|
||||||
end_time_ms=int(run.end_time.timestamp() * 1000),
|
"""
|
||||||
status_code=trace_tree.StatusCode.SUCCESS
|
base_span = self._convert_run_to_wb_span(run)
|
||||||
if run.error is None
|
|
||||||
else trace_tree.StatusCode.ERROR,
|
|
||||||
status_message=run.error,
|
|
||||||
attributes=attributes,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
base_span.results = [
|
||||||
|
self.trace_tree.Result(
|
||||||
|
inputs=_serialize_inputs(run.inputs), outputs=run.outputs
|
||||||
|
)
|
||||||
|
]
|
||||||
|
base_span.child_spans = [
|
||||||
|
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
|
||||||
|
]
|
||||||
|
base_span.span_kind = (
|
||||||
|
self.trace_tree.SpanKind.AGENT
|
||||||
|
if "agent" in run.name.lower()
|
||||||
|
else self.trace_tree.SpanKind.CHAIN
|
||||||
|
)
|
||||||
|
|
||||||
def _replace_type_with_kind(data: Any) -> Any:
|
return base_span
|
||||||
if isinstance(data, dict):
|
|
||||||
# W&B TraceTree expects "_kind" instead of "_type" since `_type` is special
|
def _convert_tool_run_to_wb_span(self, run: Run) -> "Span":
|
||||||
# in W&B.
|
"""Converts a LangChain Tool Run into a W&B Trace Span.
|
||||||
if "_type" in data:
|
:param run: The LangChain Tool Run to convert.
|
||||||
_type = data.pop("_type")
|
:return: The converted W&B Trace Span.
|
||||||
data["_kind"] = _type
|
"""
|
||||||
return {k: _replace_type_with_kind(v) for k, v in data.items()}
|
base_span = self._convert_run_to_wb_span(run)
|
||||||
elif isinstance(data, list):
|
base_span.results = [
|
||||||
return [_replace_type_with_kind(v) for v in data]
|
self.trace_tree.Result(
|
||||||
elif isinstance(data, tuple):
|
inputs=_serialize_inputs(run.inputs), outputs=run.outputs
|
||||||
return tuple(_replace_type_with_kind(v) for v in data)
|
)
|
||||||
elif isinstance(data, set):
|
]
|
||||||
return {_replace_type_with_kind(v) for v in data}
|
base_span.child_spans = [
|
||||||
else:
|
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
|
||||||
return data
|
]
|
||||||
|
base_span.span_kind = self.trace_tree.SpanKind.TOOL
|
||||||
|
|
||||||
|
return base_span
|
||||||
|
|
||||||
|
def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
|
||||||
|
"""Utility to convert any generic LangChain Run into a W&B Trace Span.
|
||||||
|
:param run: The LangChain Run to convert.
|
||||||
|
:return: The converted W&B Trace Span.
|
||||||
|
"""
|
||||||
|
if run.run_type == RunTypeEnum.llm:
|
||||||
|
return self._convert_llm_run_to_wb_span(run)
|
||||||
|
elif run.run_type == RunTypeEnum.chain:
|
||||||
|
return self._convert_chain_run_to_wb_span(run)
|
||||||
|
elif run.run_type == RunTypeEnum.tool:
|
||||||
|
return self._convert_tool_run_to_wb_span(run)
|
||||||
|
else:
|
||||||
|
return self._convert_run_to_wb_span(run)
|
||||||
|
|
||||||
|
def process_model(self, run: Run) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Utility to process a run for wandb model_dict serialization.
|
||||||
|
:param run: The run to process.
|
||||||
|
:return: The convert model_dict to pass to WBTraceTree.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(run.json())
|
||||||
|
processed = self.flatten_run(data)
|
||||||
|
keep_keys = (
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"serialized",
|
||||||
|
"inputs",
|
||||||
|
"outputs",
|
||||||
|
"parent_run_id",
|
||||||
|
"execution_order",
|
||||||
|
)
|
||||||
|
processed = self.truncate_run_iterative(processed, keep_keys=keep_keys)
|
||||||
|
exact_keys, partial_keys = ("lc", "type"), ("api_key",)
|
||||||
|
processed = self.modify_serialized_iterative(
|
||||||
|
processed, exact_keys=exact_keys, partial_keys=partial_keys
|
||||||
|
)
|
||||||
|
output = self.build_tree(processed)
|
||||||
|
return output
|
||||||
|
except Exception as e:
|
||||||
|
if PRINT_WARNINGS:
|
||||||
|
self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
"""Utility to flatten a nest run object into a list of runs.
|
||||||
|
:param run: The base run to flatten.
|
||||||
|
:return: The flattened list of runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def flatten(child_runs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""Utility to recursively flatten a list of child runs in a run.
|
||||||
|
:param child_runs: The list of child runs to flatten.
|
||||||
|
:return: The flattened list of runs.
|
||||||
|
"""
|
||||||
|
if child_runs is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for item in child_runs:
|
||||||
|
child_runs = item.pop("child_runs", [])
|
||||||
|
result.append(item)
|
||||||
|
result.extend(flatten(child_runs))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return flatten([run])
|
||||||
|
|
||||||
|
def truncate_run_iterative(
|
||||||
|
self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Utility to truncate a list of runs dictionaries to only keep the specified
|
||||||
|
keys in each run.
|
||||||
|
:param runs: The list of runs to truncate.
|
||||||
|
:param keep_keys: The keys to keep in each run.
|
||||||
|
:return: The truncated list of runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def truncate_single(run: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Utility to truncate a single run dictionary to only keep the specified
|
||||||
|
keys.
|
||||||
|
:param run: The run dictionary to truncate.
|
||||||
|
:return: The truncated run dictionary
|
||||||
|
"""
|
||||||
|
new_dict = {}
|
||||||
|
for key in run:
|
||||||
|
if key in keep_keys:
|
||||||
|
new_dict[key] = run.get(key)
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
return list(map(truncate_single, runs))
|
||||||
|
|
||||||
|
def modify_serialized_iterative(
|
||||||
|
self,
|
||||||
|
runs: List[Dict[str, Any]],
|
||||||
|
exact_keys: Tuple[str, ...] = (),
|
||||||
|
partial_keys: Tuple[str, ...] = (),
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Utility to modify the serialized field of a list of runs dictionaries.
|
||||||
|
removes any keys that match the exact_keys and any keys that contain any of the
|
||||||
|
partial_keys.
|
||||||
|
recursively moves the dictionaries under the kwargs key to the top level.
|
||||||
|
changes the "id" field to a string "_kind" field that tells WBTraceTree how to
|
||||||
|
visualize the run. promotes the "serialized" field to the top level.
|
||||||
|
|
||||||
|
:param runs: The list of runs to modify.
|
||||||
|
:param exact_keys: A tuple of keys to remove from the serialized field.
|
||||||
|
:param partial_keys: A tuple of partial keys to remove from the serialized
|
||||||
|
field.
|
||||||
|
:return: The modified list of runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Recursively removes exact and partial keys from a dictionary.
|
||||||
|
:param obj: The dictionary to remove keys from.
|
||||||
|
:return: The modified dictionary.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
obj = {
|
||||||
|
k: v
|
||||||
|
for k, v in obj.items()
|
||||||
|
if k not in exact_keys
|
||||||
|
and not any(partial in k for partial in partial_keys)
|
||||||
|
}
|
||||||
|
for k, v in obj.items():
|
||||||
|
obj[k] = remove_exact_and_partial_keys(v)
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
obj = [remove_exact_and_partial_keys(x) for x in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def handle_id_and_kwargs(
|
||||||
|
obj: Dict[str, Any], root: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Recursively handles the id and kwargs fields of a dictionary.
|
||||||
|
changes the id field to a string "_kind" field that tells WBTraceTree how
|
||||||
|
to visualize the run. recursively moves the dictionaries under the kwargs
|
||||||
|
key to the top level.
|
||||||
|
:param obj: a run dictionary with id and kwargs fields.
|
||||||
|
:param root: whether this is the root dictionary or the serialized
|
||||||
|
dictionary.
|
||||||
|
:return: The modified dictionary.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
if ("id" in obj or "name" in obj) and not root:
|
||||||
|
_kind = obj.get("id")
|
||||||
|
if not _kind:
|
||||||
|
_kind = [obj.get("name")]
|
||||||
|
obj["_kind"] = _kind[-1]
|
||||||
|
obj.pop("id", None)
|
||||||
|
obj.pop("name", None)
|
||||||
|
if "kwargs" in obj:
|
||||||
|
kwargs = obj.pop("kwargs")
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
obj[k] = v
|
||||||
|
for k, v in obj.items():
|
||||||
|
obj[k] = handle_id_and_kwargs(v)
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
obj = [handle_id_and_kwargs(x) for x in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Transforms the serialized field of a run dictionary to be compatible
|
||||||
|
with WBTraceTree.
|
||||||
|
:param serialized: The serialized field of a run dictionary.
|
||||||
|
:return: The transformed serialized field.
|
||||||
|
"""
|
||||||
|
serialized = handle_id_and_kwargs(serialized, root=True)
|
||||||
|
serialized = remove_exact_and_partial_keys(serialized)
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Transforms a run dictionary to be compatible with WBTraceTree.
|
||||||
|
:param run: The run dictionary to transform.
|
||||||
|
:return: The transformed run dictionary.
|
||||||
|
"""
|
||||||
|
transformed_dict = transform_serialized(run)
|
||||||
|
|
||||||
|
serialized = transformed_dict.pop("serialized")
|
||||||
|
for k, v in serialized.items():
|
||||||
|
transformed_dict[k] = v
|
||||||
|
|
||||||
|
_kind = transformed_dict.get("_kind", None)
|
||||||
|
name = transformed_dict.pop("name", None)
|
||||||
|
exec_ord = transformed_dict.pop("execution_order", None)
|
||||||
|
|
||||||
|
if not name:
|
||||||
|
name = _kind
|
||||||
|
|
||||||
|
output_dict = {
|
||||||
|
f"{exec_ord}_{name}": transformed_dict,
|
||||||
|
}
|
||||||
|
return output_dict
|
||||||
|
|
||||||
|
return list(map(transform_run, runs))
|
||||||
|
|
||||||
|
def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
|
"""Builds a nested dictionary from a list of runs.
|
||||||
|
:param runs: The list of runs to build the tree from.
|
||||||
|
:return: The nested dictionary representing the langchain Run in a tree
|
||||||
|
structure compatible with WBTraceTree.
|
||||||
|
"""
|
||||||
|
id_to_data = {}
|
||||||
|
child_to_parent = {}
|
||||||
|
|
||||||
|
for entity in runs:
|
||||||
|
for key, data in entity.items():
|
||||||
|
id_val = data.pop("id", None)
|
||||||
|
parent_run_id = data.pop("parent_run_id", None)
|
||||||
|
id_to_data[id_val] = {key: data}
|
||||||
|
if parent_run_id:
|
||||||
|
child_to_parent[id_val] = parent_run_id
|
||||||
|
|
||||||
|
for child_id, parent_id in child_to_parent.items():
|
||||||
|
parent_dict = id_to_data[parent_id]
|
||||||
|
parent_dict[next(iter(parent_dict))][
|
||||||
|
next(iter(id_to_data[child_id]))
|
||||||
|
] = id_to_data[child_id][next(iter(id_to_data[child_id]))]
|
||||||
|
|
||||||
|
root_dict = next(
|
||||||
|
data for id_val, data in id_to_data.items() if id_val not in child_to_parent
|
||||||
|
)
|
||||||
|
|
||||||
|
return root_dict
|
||||||
|
|
||||||
|
|
||||||
class WandbRunArgs(TypedDict):
|
class WandbRunArgs(TypedDict):
|
||||||
@ -201,12 +438,13 @@ class WandbTracer(BaseTracer):
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import wandb python package."
|
"Could not import wandb python package."
|
||||||
"Please install it with `pip install wandb`."
|
"Please install it with `pip install -U wandb`."
|
||||||
) from e
|
) from e
|
||||||
self._wandb = wandb
|
self._wandb = wandb
|
||||||
self._trace_tree = trace_tree
|
self._trace_tree = trace_tree
|
||||||
self._run_args = run_args
|
self._run_args = run_args
|
||||||
self._ensure_run(should_print_url=(wandb.run is None))
|
self._ensure_run(should_print_url=(wandb.run is None))
|
||||||
|
self.run_processor = RunProcessor(self._wandb, self._trace_tree)
|
||||||
|
|
||||||
def finish(self) -> None:
|
def finish(self) -> None:
|
||||||
"""Waits for all asynchronous processes to finish and data to upload.
|
"""Waits for all asynchronous processes to finish and data to upload.
|
||||||
@ -219,24 +457,12 @@ class WandbTracer(BaseTracer):
|
|||||||
"""Logs a LangChain Run to W*B as a W&B Trace."""
|
"""Logs a LangChain Run to W*B as a W&B Trace."""
|
||||||
self._ensure_run()
|
self._ensure_run()
|
||||||
|
|
||||||
try:
|
root_span = self.run_processor.process_span(run)
|
||||||
root_span = _convert_lc_run_to_wb_span(self._trace_tree, run)
|
model_dict = self.run_processor.process_model(run)
|
||||||
except Exception as e:
|
|
||||||
if PRINT_WARNINGS:
|
if root_span is None:
|
||||||
self._wandb.termwarn(
|
|
||||||
f"Skipping trace saving - unable to safely convert LangChain Run "
|
|
||||||
f"into W&B Trace due to: {e}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
model_dict = None
|
|
||||||
|
|
||||||
# TODO: Add something like this once we have a way to get the clean serialized
|
|
||||||
# parent dict from a run:
|
|
||||||
# serialized_parent = safely_get_span_producing_model(run)
|
|
||||||
# if serialized_parent is not None:
|
|
||||||
# model_dict = safely_convert_model_to_dict(serialized_parent)
|
|
||||||
|
|
||||||
model_trace = self._trace_tree.WBTraceTree(
|
model_trace = self._trace_tree.WBTraceTree(
|
||||||
root_span=root_span,
|
root_span=root_span,
|
||||||
model_dict=model_dict,
|
model_dict=model_dict,
|
||||||
|
Loading…
Reference in New Issue
Block a user