From 7c03cc5ed49a6d52f963d5e322e0a7f4244c400c Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 9 Feb 2024 16:24:54 -0800 Subject: [PATCH] Support serialization when inputs/outputs contain generators (#17338) Pydantic's `dict()` function raises an error here if you pass in a generator. We have a more robust serialization function in lagnsmith that we will use instead. --- libs/core/langchain_core/tracers/langchain.py | 13 +++++++++++-- .../core/tests/unit_tests/tracers/test_langchain.py | 7 +++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 080946d3794..15cb1977c28 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -1,4 +1,5 @@ """A Tracer implementation that records to LangChain endpoint.""" + from __future__ import annotations import logging @@ -62,6 +63,14 @@ def _get_executor() -> ThreadPoolExecutor: return _EXECUTOR +def _run_to_dict(run: Run) -> dict: + return { + **run.dict(exclude={"child_runs", "inputs", "outputs"}), + "inputs": run.inputs.copy() if run.inputs is not None else None, + "outputs": run.outputs.copy() if run.outputs is not None else None, + } + + class LangChainTracer(BaseTracer): """Implementation of the SharedTracer that POSTS to the LangChain endpoint.""" @@ -150,7 +159,7 @@ class LangChainTracer(BaseTracer): def _persist_run_single(self, run: Run) -> None: """Persist a run.""" - run_dict = run.dict(exclude={"child_runs"}) + run_dict = _run_to_dict(run) run_dict["tags"] = self._get_tags(run) extra = run_dict.get("extra", {}) extra["runtime"] = get_runtime_environment() @@ -165,7 +174,7 @@ class LangChainTracer(BaseTracer): def _update_run_single(self, run: Run) -> None: """Update a run.""" try: - run_dict = run.dict() + run_dict = _run_to_dict(run) run_dict["tags"] = self._get_tags(run) self.client.update_run(run.id, **run_dict) except Exception as e: diff --git a/libs/core/tests/unit_tests/tracers/test_langchain.py b/libs/core/tests/unit_tests/tracers/test_langchain.py index 1af7fe467e6..820e547c7a9 100644 --- a/libs/core/tests/unit_tests/tracers/test_langchain.py +++ b/libs/core/tests/unit_tests/tracers/test_langchain.py @@ -22,6 +22,7 @@ def test_example_id_assignment_threadsafe() -> None: return unittest.mock.MagicMock() client = unittest.mock.MagicMock(spec=Client) + client.tracing_queue = None client.create_run = mock_create_run tracer = LangChainTracer(client=client) old_persist_run_single = tracer._persist_run_single @@ -35,6 +36,7 @@ def test_example_id_assignment_threadsafe() -> None: ): run_id_1 = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a") run_id_2 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7bd1e1") + run_id_3 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7cd1e1") example_id_1 = UUID("57e42c57-8c79-4d9f-8765-bf6cd3a98055") tracer.example_id = example_id_1 tracer.on_llm_start({"name": "example_1"}, ["foo"], run_id=run_id_1) @@ -44,9 +46,14 @@ def test_example_id_assignment_threadsafe() -> None: tracer.on_llm_start({"name": "example_2"}, ["foo"], run_id=run_id_2) tracer.on_llm_end(LLMResult(generations=[], llm_output={}), run_id=run_id_2) tracer.example_id = None + tracer.on_chain_start( + {"name": "no_examples"}, {"inputs": (i for i in range(10))}, run_id=run_id_3 + ) + tracer.on_chain_error(ValueError("Foo bar"), run_id=run_id_3) expected_example_ids = { run_id_1: example_id_1, run_id_2: example_id_2, + run_id_3: None, } tracer.wait_for_futures() assert example_ids == expected_example_ids