From 08a8363fc6a7a351c78974ad98fe0db2fe39430d Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Tue, 15 Aug 2023 13:35:12 +0530 Subject: [PATCH] feat(integration): Add support to serialize protobufs in WandbTracer (#8914) This PR adds serialization support for protocol buffers in `WandbTracer`. This allows code generation chains to be visualized. Additionally, it fixes a minor bug where the settings are not honored when a run is initialized before using the `WandbTracer` @agola11 --------- Co-authored-by: Bharat Ramanathan Co-authored-by: Bagatur --- .../langchain/callbacks/tracers/wandb.py | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/libs/langchain/langchain/callbacks/tracers/wandb.py b/libs/langchain/langchain/callbacks/tracers/wandb.py index d747e765e3e..05f417723d6 100644 --- a/libs/langchain/langchain/callbacks/tracers/wandb.py +++ b/libs/langchain/langchain/callbacks/tracers/wandb.py @@ -27,12 +27,21 @@ if TYPE_CHECKING: PRINT_WARNINGS = True -def _serialize_inputs(run_inputs: dict) -> dict: - if "input_documents" in run_inputs: - docs = run_inputs["input_documents"] - return {f"input_document_{i}": doc.json() for i, doc in enumerate(docs)} - else: - return run_inputs +def _serialize_io(run_inputs: dict) -> dict: + from google.protobuf.json_format import MessageToJson + from google.protobuf.message import Message + + serialized_inputs = {} + for key, value in run_inputs.items(): + if isinstance(value, Message): + serialized_inputs[key] = MessageToJson(value) + elif key == "input_documents": + serialized_inputs.update( + {f"input_document_{i}": doc.json() for i, doc in enumerate(value)} + ) + else: + serialized_inputs[key] = value + return serialized_inputs class RunProcessor: @@ -117,7 +126,7 @@ class RunProcessor: base_span.results = [ self.trace_tree.Result( - inputs=_serialize_inputs(run.inputs), outputs=run.outputs + inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs) ) ] 
base_span.child_spans = [ @@ -139,7 +148,7 @@ class RunProcessor: base_span = self._convert_run_to_wb_span(run) base_span.results = [ self.trace_tree.Result( - inputs=_serialize_inputs(run.inputs), outputs=run.outputs + inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs) ) ] base_span.child_spans = [ @@ -476,29 +485,25 @@ class WandbTracer(BaseTracer): If not, will start a new run with the provided run_args. """ if self._wandb.run is None: - # Make a shallow copy of the run args, so we don't modify the original run_args = self._run_args or {} # type: ignore run_args: dict = {**run_args} # type: ignore - # Prefer to run in silent mode since W&B has a lot of output - # which can be undesirable when dealing with text-based models. if "settings" not in run_args: # type: ignore run_args["settings"] = {"silent": True} # type: ignore - # Start the run and add the stream table self._wandb.init(**run_args) - if self._wandb.run is not None: - if should_print_url: - run_url = self._wandb.run.settings.run_url - self._wandb.termlog( - f"Streaming LangChain activity to W&B at {run_url}\n" - "`WandbTracer` is currently in beta.\n" - "Please report any issues to " - "https://github.com/wandb/wandb/issues with the tag " - "`langchain`." - ) + if self._wandb.run is not None: + if should_print_url: + run_url = self._wandb.run.settings.run_url + self._wandb.termlog( + f"Streaming LangChain activity to W&B at {run_url}\n" + "`WandbTracer` is currently in beta.\n" + "Please report any issues to " + "https://github.com/wandb/wandb/issues with the tag " + "`langchain`." + ) - self._wandb.run._label(repo="langchain") + self._wandb.run._label(repo="langchain") def _persist_run(self, run: "Run") -> None: """Persist a run."""