mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-02 03:15:11 +00:00
Add collect_runs callback (#9885)
This commit is contained in:
parent
3103f07e03
commit
907c57e324
@ -20,6 +20,7 @@ from langchain.callbacks.human import HumanApprovalCallbackHandler
|
|||||||
from langchain.callbacks.infino_callback import InfinoCallbackHandler
|
from langchain.callbacks.infino_callback import InfinoCallbackHandler
|
||||||
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
|
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
|
collect_runs,
|
||||||
get_openai_callback,
|
get_openai_callback,
|
||||||
tracing_enabled,
|
tracing_enabled,
|
||||||
tracing_v2_enabled,
|
tracing_v2_enabled,
|
||||||
@ -66,6 +67,7 @@ __all__ = [
|
|||||||
"get_openai_callback",
|
"get_openai_callback",
|
||||||
"tracing_enabled",
|
"tracing_enabled",
|
||||||
"tracing_v2_enabled",
|
"tracing_v2_enabled",
|
||||||
|
"collect_runs",
|
||||||
"wandb_tracing_enabled",
|
"wandb_tracing_enabled",
|
||||||
"FlyteCallbackHandler",
|
"FlyteCallbackHandler",
|
||||||
"SageMakerCallbackHandler",
|
"SageMakerCallbackHandler",
|
||||||
|
@ -38,6 +38,7 @@ from langchain.callbacks.base import (
|
|||||||
)
|
)
|
||||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
from langchain.callbacks.tracers import run_collector
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
||||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
@ -75,6 +76,11 @@ tracing_v2_callback_var: ContextVar[
|
|||||||
] = ContextVar( # noqa: E501
|
] = ContextVar( # noqa: E501
|
||||||
"tracing_callback_v2", default=None
|
"tracing_callback_v2", default=None
|
||||||
)
|
)
|
||||||
|
run_collector_var: ContextVar[
|
||||||
|
Optional[run_collector.RunCollectorCallbackHandler]
|
||||||
|
] = ContextVar( # noqa: E501
|
||||||
|
"run_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_debug() -> bool:
|
def _get_debug() -> bool:
|
||||||
@ -184,6 +190,24 @@ def tracing_v2_enabled(
|
|||||||
tracing_v2_callback_var.set(None)
|
tracing_v2_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
|
||||||
|
"""Collect all run traces in context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with collect_runs() as runs_cb:
|
||||||
|
chain.invoke("foo")
|
||||||
|
run_id = runs_cb.traced_runs[0].id
|
||||||
|
"""
|
||||||
|
cb = run_collector.RunCollectorCallbackHandler()
|
||||||
|
run_collector_var.set(cb)
|
||||||
|
yield cb
|
||||||
|
run_collector_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def trace_as_chain_group(
|
def trace_as_chain_group(
|
||||||
group_name: str,
|
group_name: str,
|
||||||
@ -1712,6 +1736,7 @@ def _configure(
|
|||||||
tracer_project = os.environ.get(
|
tracer_project = os.environ.get(
|
||||||
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
|
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
|
||||||
)
|
)
|
||||||
|
run_collector_ = run_collector_var.get()
|
||||||
debug = _get_debug()
|
debug = _get_debug()
|
||||||
if (
|
if (
|
||||||
verbose
|
verbose
|
||||||
@ -1774,4 +1799,6 @@ def _configure(
|
|||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
):
|
):
|
||||||
callback_manager.add_handler(open_ai, True)
|
callback_manager.add_handler(open_ai, True)
|
||||||
|
if run_collector_ is not None:
|
||||||
|
callback_manager.add_handler(run_collector_, False)
|
||||||
return callback_manager
|
return callback_manager
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
"""Test the run collector."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from langchain.callbacks import collect_runs
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_runs() -> None:
|
||||||
|
llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
|
||||||
|
with collect_runs() as cb:
|
||||||
|
llm.predict("hi")
|
||||||
|
assert cb.traced_runs
|
||||||
|
assert len(cb.traced_runs) == 1
|
||||||
|
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
|
||||||
|
assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}
|
Loading…
Reference in New Issue
Block a user