feat(integration): Add support to serialize protobufs in WandbTracer (#8914)

This PR adds serialization support for protocol buffers in
`WandbTracer`. This allows code generation chains to be visualized.
It also fixes a minor bug where the settings are not
honored when a run is initialized before using the `WandbTracer`.

@agola11

---------

Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Bharat Ramanathan 2023-08-15 13:35:12 +05:30 committed by GitHub
parent 5e43768f61
commit 08a8363fc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,12 +27,21 @@ if TYPE_CHECKING:
PRINT_WARNINGS = True PRINT_WARNINGS = True
def _serialize_inputs(run_inputs: dict) -> dict: def _serialize_io(run_inputs: dict) -> dict:
if "input_documents" in run_inputs: from google.protobuf.json_format import MessageToJson
docs = run_inputs["input_documents"] from google.protobuf.message import Message
return {f"input_document_{i}": doc.json() for i, doc in enumerate(docs)}
serialized_inputs = {}
for key, value in run_inputs.items():
if isinstance(value, Message):
serialized_inputs[key] = MessageToJson(value)
elif key == "input_documents":
serialized_inputs.update(
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
)
else: else:
return run_inputs serialized_inputs[key] = value
return serialized_inputs
class RunProcessor: class RunProcessor:
@ -117,7 +126,7 @@ class RunProcessor:
base_span.results = [ base_span.results = [
self.trace_tree.Result( self.trace_tree.Result(
inputs=_serialize_inputs(run.inputs), outputs=run.outputs inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
) )
] ]
base_span.child_spans = [ base_span.child_spans = [
@ -139,7 +148,7 @@ class RunProcessor:
base_span = self._convert_run_to_wb_span(run) base_span = self._convert_run_to_wb_span(run)
base_span.results = [ base_span.results = [
self.trace_tree.Result( self.trace_tree.Result(
inputs=_serialize_inputs(run.inputs), outputs=run.outputs inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
) )
] ]
base_span.child_spans = [ base_span.child_spans = [
@ -476,16 +485,12 @@ class WandbTracer(BaseTracer):
If not, will start a new run with the provided run_args. If not, will start a new run with the provided run_args.
""" """
if self._wandb.run is None: if self._wandb.run is None:
# Make a shallow copy of the run args, so we don't modify the original
run_args = self._run_args or {} # type: ignore run_args = self._run_args or {} # type: ignore
run_args: dict = {**run_args} # type: ignore run_args: dict = {**run_args} # type: ignore
# Prefer to run in silent mode since W&B has a lot of output
# which can be undesirable when dealing with text-based models.
if "settings" not in run_args: # type: ignore if "settings" not in run_args: # type: ignore
run_args["settings"] = {"silent": True} # type: ignore run_args["settings"] = {"silent": True} # type: ignore
# Start the run and add the stream table
self._wandb.init(**run_args) self._wandb.init(**run_args)
if self._wandb.run is not None: if self._wandb.run is not None:
if should_print_url: if should_print_url: