mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 10:54:15 +00:00
Add collect_runs callback (#9885)
This commit is contained in:
parent
3103f07e03
commit
907c57e324
@ -20,6 +20,7 @@ from langchain.callbacks.human import HumanApprovalCallbackHandler
|
||||
from langchain.callbacks.infino_callback import InfinoCallbackHandler
|
||||
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
|
||||
from langchain.callbacks.manager import (
|
||||
collect_runs,
|
||||
get_openai_callback,
|
||||
tracing_enabled,
|
||||
tracing_v2_enabled,
|
||||
@ -66,6 +67,7 @@ __all__ = [
|
||||
"get_openai_callback",
|
||||
"tracing_enabled",
|
||||
"tracing_v2_enabled",
|
||||
"collect_runs",
|
||||
"wandb_tracing_enabled",
|
||||
"FlyteCallbackHandler",
|
||||
"SageMakerCallbackHandler",
|
||||
|
@ -38,6 +38,7 @@ from langchain.callbacks.base import (
|
||||
)
|
||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers import run_collector
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||
@ -75,6 +76,11 @@ tracing_v2_callback_var: ContextVar[
|
||||
] = ContextVar( # noqa: E501
|
||||
"tracing_callback_v2", default=None
|
||||
)
|
||||
run_collector_var: ContextVar[
|
||||
Optional[run_collector.RunCollectorCallbackHandler]
|
||||
] = ContextVar( # noqa: E501
|
||||
"run_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def _get_debug() -> bool:
|
||||
@ -184,6 +190,24 @@ def tracing_v2_enabled(
|
||||
tracing_v2_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
|
||||
"""Collect all run traces in context.
|
||||
|
||||
Returns:
|
||||
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
|
||||
|
||||
Example:
|
||||
>>> with collect_runs() as runs_cb:
|
||||
chain.invoke("foo")
|
||||
run_id = runs_cb.traced_runs[0].id
|
||||
"""
|
||||
cb = run_collector.RunCollectorCallbackHandler()
|
||||
run_collector_var.set(cb)
|
||||
yield cb
|
||||
run_collector_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_as_chain_group(
|
||||
group_name: str,
|
||||
@ -1712,6 +1736,7 @@ def _configure(
|
||||
tracer_project = os.environ.get(
|
||||
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
|
||||
)
|
||||
run_collector_ = run_collector_var.get()
|
||||
debug = _get_debug()
|
||||
if (
|
||||
verbose
|
||||
@ -1774,4 +1799,6 @@ def _configure(
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(open_ai, True)
|
||||
if run_collector_ is not None:
|
||||
callback_manager.add_handler(run_collector_, False)
|
||||
return callback_manager
|
||||
|
@ -0,0 +1,16 @@
|
||||
"""Test the run collector."""
|
||||
|
||||
import uuid
|
||||
|
||||
from langchain.callbacks import collect_runs
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def test_collect_runs() -> None:
|
||||
llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
|
||||
with collect_runs() as cb:
|
||||
llm.predict("hi")
|
||||
assert cb.traced_runs
|
||||
assert len(cb.traced_runs) == 1
|
||||
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
|
||||
assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}
|
Loading…
Reference in New Issue
Block a user