Support serialization when inputs/outputs contain generators (#17338)
Pydantic's `dict()` function raises an error here if you pass in a generator. We have a more robust serialization function in `langsmith` that we will use instead.
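For context, a minimal sketch of the approach the new `_run_to_dict` helper takes: exclude `inputs`/`outputs` from pydantic's `dict()` traversal and pass them through as shallow copies, so generator values are never handed to the serializer. `FakeRun` and its fields are hypothetical stand-ins for the real `Run` schema, and a pydantic v1-style `.dict()` API is assumed.

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel  # v1-style .dict() API assumed


class FakeRun(BaseModel):
    """Hypothetical stand-in for the real Run schema."""

    child_runs: List[Any] = []
    inputs: Optional[Dict[str, Any]] = None
    outputs: Optional[Dict[str, Any]] = None


def run_to_dict(run: FakeRun) -> dict:
    # Keep inputs/outputs out of pydantic's dict() call and shallow-copy
    # them instead, so generator values are passed through untouched.
    return {
        **run.dict(exclude={"child_runs", "inputs", "outputs"}),
        "inputs": run.inputs.copy() if run.inputs is not None else None,
        "outputs": run.outputs.copy() if run.outputs is not None else None,
    }


run = FakeRun(inputs={"stream": (i for i in range(10))})
payload = run_to_dict(run)  # the generator value survives as-is
```

Note the shallow copy: the generator is not consumed or converted here; the more robust serialization downstream is left to deal with it.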
parent 3a2eb6e12b
commit 7c03cc5ed4
@@ -1,4 +1,5 @@
 """A Tracer implementation that records to LangChain endpoint."""
+
 from __future__ import annotations

 import logging
@@ -62,6 +63,14 @@ def _get_executor() -> ThreadPoolExecutor:
     return _EXECUTOR


+def _run_to_dict(run: Run) -> dict:
+    return {
+        **run.dict(exclude={"child_runs", "inputs", "outputs"}),
+        "inputs": run.inputs.copy() if run.inputs is not None else None,
+        "outputs": run.outputs.copy() if run.outputs is not None else None,
+    }
+
+
 class LangChainTracer(BaseTracer):
     """Implementation of the SharedTracer that POSTS to the LangChain endpoint."""

@@ -150,7 +159,7 @@ class LangChainTracer(BaseTracer):

     def _persist_run_single(self, run: Run) -> None:
         """Persist a run."""
-        run_dict = run.dict(exclude={"child_runs"})
+        run_dict = _run_to_dict(run)
         run_dict["tags"] = self._get_tags(run)
         extra = run_dict.get("extra", {})
         extra["runtime"] = get_runtime_environment()
@@ -165,7 +174,7 @@
     def _update_run_single(self, run: Run) -> None:
         """Update a run."""
         try:
-            run_dict = run.dict()
+            run_dict = _run_to_dict(run)
             run_dict["tags"] = self._get_tags(run)
             self.client.update_run(run.id, **run_dict)
         except Exception as e:
@@ -22,6 +22,7 @@ def test_example_id_assignment_threadsafe() -> None:
         return unittest.mock.MagicMock()

     client = unittest.mock.MagicMock(spec=Client)
+    client.tracing_queue = None
     client.create_run = mock_create_run
     tracer = LangChainTracer(client=client)
     old_persist_run_single = tracer._persist_run_single
@@ -35,6 +36,7 @@ def test_example_id_assignment_threadsafe() -> None:
     ):
         run_id_1 = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
         run_id_2 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7bd1e1")
+        run_id_3 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7cd1e1")
         example_id_1 = UUID("57e42c57-8c79-4d9f-8765-bf6cd3a98055")
         tracer.example_id = example_id_1
         tracer.on_llm_start({"name": "example_1"}, ["foo"], run_id=run_id_1)
@@ -44,9 +46,14 @@ def test_example_id_assignment_threadsafe() -> None:
         tracer.on_llm_start({"name": "example_2"}, ["foo"], run_id=run_id_2)
         tracer.on_llm_end(LLMResult(generations=[], llm_output={}), run_id=run_id_2)
         tracer.example_id = None
+        tracer.on_chain_start(
+            {"name": "no_examples"}, {"inputs": (i for i in range(10))}, run_id=run_id_3
+        )
+        tracer.on_chain_error(ValueError("Foo bar"), run_id=run_id_3)
         expected_example_ids = {
             run_id_1: example_id_1,
             run_id_2: example_id_2,
+            run_id_3: None,
         }
         tracer.wait_for_futures()
         assert example_ids == expected_example_ids