Add collect_runs callback (#9885)

This commit is contained in:
William FH 2023-08-28 15:30:41 -07:00 committed by GitHub
parent 3103f07e03
commit 907c57e324
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from langchain.callbacks.human import HumanApprovalCallbackHandler
from langchain.callbacks.infino_callback import InfinoCallbackHandler
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
from langchain.callbacks.manager import (
    collect_runs,
    get_openai_callback,
    tracing_enabled,
    tracing_v2_enabled,
@ -66,6 +67,7 @@ __all__ = [
    "get_openai_callback",
    "tracing_enabled",
    "tracing_v2_enabled",
    "collect_runs",
    "wandb_tracing_enabled",
    "FlyteCallbackHandler",
    "SageMakerCallbackHandler",

View File

@ -38,6 +38,7 @@ from langchain.callbacks.base import (
) )
from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers import run_collector
from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1 from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
@ -75,6 +76,11 @@ tracing_v2_callback_var: ContextVar[
] = ContextVar( # noqa: E501 ] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None "tracing_callback_v2", default=None
) )
# Context-local holder for the active RunCollectorCallbackHandler.
# Set/cleared by collect_runs() and read back in _configure(), so runs
# traced inside a `with collect_runs():` block reach the collector.
run_collector_var: ContextVar[
    Optional[run_collector.RunCollectorCallbackHandler]
] = ContextVar(  # noqa: E501
    "run_collector", default=None
)
def _get_debug() -> bool: def _get_debug() -> bool:
@ -184,6 +190,24 @@ def tracing_v2_enabled(
tracing_v2_callback_var.set(None) tracing_v2_callback_var.set(None)
@contextmanager
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
    """Collect all run traces in context.

    Yields:
        run_collector.RunCollectorCallbackHandler: The run collector
            callback handler; its ``traced_runs`` list accumulates every
            run executed within the ``with`` block.

    Example:
        >>> with collect_runs() as runs_cb:
        ...     chain.invoke("foo")
        ...     run_id = runs_cb.traced_runs[0].id
    """
    cb = run_collector.RunCollectorCallbackHandler()
    run_collector_var.set(cb)
    try:
        yield cb
    finally:
        # Clear the context var even when the with-body raises; otherwise a
        # failed run would leak this collector into all subsequent runs in
        # the same context.
        run_collector_var.set(None)
@contextmanager @contextmanager
def trace_as_chain_group( def trace_as_chain_group(
group_name: str, group_name: str,
@ -1712,6 +1736,7 @@ def _configure(
tracer_project = os.environ.get( tracer_project = os.environ.get(
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default") "LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
) )
run_collector_ = run_collector_var.get()
debug = _get_debug() debug = _get_debug()
if ( if (
verbose verbose
@ -1774,4 +1799,6 @@ def _configure(
for handler in callback_manager.handlers for handler in callback_manager.handlers
): ):
callback_manager.add_handler(open_ai, True) callback_manager.add_handler(open_ai, True)
if run_collector_ is not None:
callback_manager.add_handler(run_collector_, False)
return callback_manager return callback_manager

View File

@ -0,0 +1,16 @@
"""Test the run collector."""
import uuid
from langchain.callbacks import collect_runs
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_collect_runs() -> None:
    """Smoke-test collect_runs(): one LLM call yields one traced run."""
    llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
    with collect_runs() as cb:
        llm.predict("hi")
        # Exactly one run should have been collected for the single call.
        assert cb.traced_runs
        assert len(cb.traced_runs) == 1
        # The traced run carries a UUID id and records the original prompt.
        assert isinstance(cb.traced_runs[0].id, uuid.UUID)
        assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}