mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
Fix Pickle Error (#12141)
If non-pickleable objects (like locks) get passed to the tracing callback, they'll fail in the deepcopy. Fallback to a shallow copy in these instances .
This commit is contained in:
parent
95a1b598fe
commit
4f23aa677a
@ -411,11 +411,12 @@ def _handle_event(
|
|||||||
handler_name = handler.__class__.__name__
|
handler_name = handler.__class__.__name__
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"NotImplementedError in {handler_name}.{event_name}"
|
f"NotImplementedError in {handler_name}.{event_name}"
|
||||||
f" callback: {e}"
|
f" callback: {repr(e)}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
|
f"Error in {handler.__class__.__name__}.{event_name} callback:"
|
||||||
|
f" {repr(e)}"
|
||||||
)
|
)
|
||||||
if handler.raise_error:
|
if handler.raise_error:
|
||||||
raise e
|
raise e
|
||||||
@ -496,11 +497,12 @@ async def _ahandle_event_for_handler(
|
|||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"NotImplementedError in {handler.__class__.__name__}.{event_name}"
|
f"NotImplementedError in {handler.__class__.__name__}.{event_name}"
|
||||||
f" callback: {e}"
|
f" callback: {repr(e)}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
|
f"Error in {handler.__class__.__name__}.{event_name} callback:"
|
||||||
|
f" {repr(e)}"
|
||||||
)
|
)
|
||||||
if handler.raise_error:
|
if handler.raise_error:
|
||||||
raise e
|
raise e
|
||||||
|
@ -64,6 +64,16 @@ def _get_executor() -> ThreadPoolExecutor:
|
|||||||
return _EXECUTOR
|
return _EXECUTOR
|
||||||
|
|
||||||
|
|
||||||
|
def _copy(run: Run) -> Run:
|
||||||
|
"""Copy a run."""
|
||||||
|
try:
|
||||||
|
return run.copy(deep=True)
|
||||||
|
except TypeError:
|
||||||
|
# Fallback in case the object contains a lock or other
|
||||||
|
# non-pickleable object
|
||||||
|
return run.copy()
|
||||||
|
|
||||||
|
|
||||||
class LangChainTracer(BaseTracer):
|
class LangChainTracer(BaseTracer):
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
@ -192,63 +202,63 @@ class LangChainTracer(BaseTracer):
|
|||||||
"""Persist an LLM run."""
|
"""Persist an LLM run."""
|
||||||
if run.parent_run_id is None:
|
if run.parent_run_id is None:
|
||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
self._submit(self._persist_run_single, run.copy(deep=True))
|
self._submit(self._persist_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_chat_model_start(self, run: Run) -> None:
|
def _on_chat_model_start(self, run: Run) -> None:
|
||||||
"""Persist an LLM run."""
|
"""Persist an LLM run."""
|
||||||
if run.parent_run_id is None:
|
if run.parent_run_id is None:
|
||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
self._submit(self._persist_run_single, run.copy(deep=True))
|
self._submit(self._persist_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_llm_end(self, run: Run) -> None:
|
def _on_llm_end(self, run: Run) -> None:
|
||||||
"""Process the LLM Run."""
|
"""Process the LLM Run."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_llm_error(self, run: Run) -> None:
|
def _on_llm_error(self, run: Run) -> None:
|
||||||
"""Process the LLM Run upon error."""
|
"""Process the LLM Run upon error."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_chain_start(self, run: Run) -> None:
|
def _on_chain_start(self, run: Run) -> None:
|
||||||
"""Process the Chain Run upon start."""
|
"""Process the Chain Run upon start."""
|
||||||
if run.parent_run_id is None:
|
if run.parent_run_id is None:
|
||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
self._submit(self._persist_run_single, run.copy(deep=True))
|
self._submit(self._persist_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_chain_end(self, run: Run) -> None:
|
def _on_chain_end(self, run: Run) -> None:
|
||||||
"""Process the Chain Run."""
|
"""Process the Chain Run."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_chain_error(self, run: Run) -> None:
|
def _on_chain_error(self, run: Run) -> None:
|
||||||
"""Process the Chain Run upon error."""
|
"""Process the Chain Run upon error."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_tool_start(self, run: Run) -> None:
|
def _on_tool_start(self, run: Run) -> None:
|
||||||
"""Process the Tool Run upon start."""
|
"""Process the Tool Run upon start."""
|
||||||
if run.parent_run_id is None:
|
if run.parent_run_id is None:
|
||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
self._submit(self._persist_run_single, run.copy(deep=True))
|
self._submit(self._persist_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_tool_end(self, run: Run) -> None:
|
def _on_tool_end(self, run: Run) -> None:
|
||||||
"""Process the Tool Run."""
|
"""Process the Tool Run."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_tool_error(self, run: Run) -> None:
|
def _on_tool_error(self, run: Run) -> None:
|
||||||
"""Process the Tool Run upon error."""
|
"""Process the Tool Run upon error."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_retriever_start(self, run: Run) -> None:
|
def _on_retriever_start(self, run: Run) -> None:
|
||||||
"""Process the Retriever Run upon start."""
|
"""Process the Retriever Run upon start."""
|
||||||
if run.parent_run_id is None:
|
if run.parent_run_id is None:
|
||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
self._submit(self._persist_run_single, run.copy(deep=True))
|
self._submit(self._persist_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_retriever_end(self, run: Run) -> None:
|
def _on_retriever_end(self, run: Run) -> None:
|
||||||
"""Process the Retriever Run."""
|
"""Process the Retriever Run."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def _on_retriever_error(self, run: Run) -> None:
|
def _on_retriever_error(self, run: Run) -> None:
|
||||||
"""Process the Retriever Run upon error."""
|
"""Process the Retriever Run upon error."""
|
||||||
self._submit(self._update_run_single, run.copy(deep=True))
|
self._submit(self._update_run_single, _copy(run))
|
||||||
|
|
||||||
def wait_for_futures(self) -> None:
|
def wait_for_futures(self) -> None:
|
||||||
"""Wait for the given futures to complete."""
|
"""Wait for the given futures to complete."""
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import unittest.mock
|
import unittest.mock
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -47,3 +48,17 @@ def test_example_id_assignment_threadsafe() -> None:
|
|||||||
}
|
}
|
||||||
tracer.wait_for_futures()
|
tracer.wait_for_futures()
|
||||||
assert example_ids == expected_example_ids
|
assert example_ids == expected_example_ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_lock() -> None:
|
||||||
|
"""Test that example assigned at callback start/end is honored."""
|
||||||
|
|
||||||
|
client = unittest.mock.MagicMock(spec=Client)
|
||||||
|
tracer = LangChainTracer(client=client)
|
||||||
|
|
||||||
|
with unittest.mock.patch.object(tracer, "_persist_run_single", new=lambda _: _):
|
||||||
|
run_id_1 = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
|
||||||
|
lock = threading.Lock()
|
||||||
|
tracer.on_chain_start({"name": "example_1"}, {"input": lock}, run_id=run_id_1)
|
||||||
|
tracer.on_chain_end({}, run_id=run_id_1)
|
||||||
|
tracer.wait_for_futures()
|
||||||
|
Loading…
Reference in New Issue
Block a user