mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
feat(integration): Add support to serialize protobufs in WandbTracer (#8914)
This PR adds serialization support for protocol bufferes in `WandbTracer`. This allows code generation chains to be visualized. Additionally, it also fixes a minor bug where the settings are not honored when a run is initialized before using the `WandbTracer` @agola11 --------- Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
5e43768f61
commit
08a8363fc6
@ -27,12 +27,21 @@ if TYPE_CHECKING:
|
||||
PRINT_WARNINGS = True
|
||||
|
||||
|
||||
def _serialize_inputs(run_inputs: dict) -> dict:
|
||||
if "input_documents" in run_inputs:
|
||||
docs = run_inputs["input_documents"]
|
||||
return {f"input_document_{i}": doc.json() for i, doc in enumerate(docs)}
|
||||
else:
|
||||
return run_inputs
|
||||
def _serialize_io(run_inputs: dict) -> dict:
|
||||
from google.protobuf.json_format import MessageToJson
|
||||
from google.protobuf.message import Message
|
||||
|
||||
serialized_inputs = {}
|
||||
for key, value in run_inputs.items():
|
||||
if isinstance(value, Message):
|
||||
serialized_inputs[key] = MessageToJson(value)
|
||||
elif key == "input_documents":
|
||||
serialized_inputs.update(
|
||||
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
|
||||
)
|
||||
else:
|
||||
serialized_inputs[key] = value
|
||||
return serialized_inputs
|
||||
|
||||
|
||||
class RunProcessor:
|
||||
@ -117,7 +126,7 @@ class RunProcessor:
|
||||
|
||||
base_span.results = [
|
||||
self.trace_tree.Result(
|
||||
inputs=_serialize_inputs(run.inputs), outputs=run.outputs
|
||||
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
|
||||
)
|
||||
]
|
||||
base_span.child_spans = [
|
||||
@ -139,7 +148,7 @@ class RunProcessor:
|
||||
base_span = self._convert_run_to_wb_span(run)
|
||||
base_span.results = [
|
||||
self.trace_tree.Result(
|
||||
inputs=_serialize_inputs(run.inputs), outputs=run.outputs
|
||||
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
|
||||
)
|
||||
]
|
||||
base_span.child_spans = [
|
||||
@ -476,29 +485,25 @@ class WandbTracer(BaseTracer):
|
||||
If not, will start a new run with the provided run_args.
|
||||
"""
|
||||
if self._wandb.run is None:
|
||||
# Make a shallow copy of the run args, so we don't modify the original
|
||||
run_args = self._run_args or {} # type: ignore
|
||||
run_args: dict = {**run_args} # type: ignore
|
||||
|
||||
# Prefer to run in silent mode since W&B has a lot of output
|
||||
# which can be undesirable when dealing with text-based models.
|
||||
if "settings" not in run_args: # type: ignore
|
||||
run_args["settings"] = {"silent": True} # type: ignore
|
||||
|
||||
# Start the run and add the stream table
|
||||
self._wandb.init(**run_args)
|
||||
if self._wandb.run is not None:
|
||||
if should_print_url:
|
||||
run_url = self._wandb.run.settings.run_url
|
||||
self._wandb.termlog(
|
||||
f"Streaming LangChain activity to W&B at {run_url}\n"
|
||||
"`WandbTracer` is currently in beta.\n"
|
||||
"Please report any issues to "
|
||||
"https://github.com/wandb/wandb/issues with the tag "
|
||||
"`langchain`."
|
||||
)
|
||||
if self._wandb.run is not None:
|
||||
if should_print_url:
|
||||
run_url = self._wandb.run.settings.run_url
|
||||
self._wandb.termlog(
|
||||
f"Streaming LangChain activity to W&B at {run_url}\n"
|
||||
"`WandbTracer` is currently in beta.\n"
|
||||
"Please report any issues to "
|
||||
"https://github.com/wandb/wandb/issues with the tag "
|
||||
"`langchain`."
|
||||
)
|
||||
|
||||
self._wandb.run._label(repo="langchain")
|
||||
self._wandb.run._label(repo="langchain")
|
||||
|
||||
def _persist_run(self, run: "Run") -> None:
|
||||
"""Persist a run."""
|
||||
|
Loading…
Reference in New Issue
Block a user