mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Tracing Group (#5326)
Add context manager to group all runs under a virtual parent --------- Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
This commit is contained in:
parent
d5b1608216
commit
84a46753ab
@ -5,9 +5,20 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import langchain
|
||||
@ -116,6 +127,58 @@ def tracing_v2_enabled(
|
||||
tracing_v2_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_as_chain_group(
|
||||
group_name: str,
|
||||
*,
|
||||
session_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
session_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Generator[CallbackManager, None, None]:
|
||||
"""Get a callback manager for a chain group in a context manager."""
|
||||
cb = LangChainTracer(
|
||||
tenant_id=tenant_id,
|
||||
session_name=session_name,
|
||||
example_id=example_id,
|
||||
session_extra=session_extra,
|
||||
)
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[cb],
|
||||
)
|
||||
|
||||
run_manager = cm.on_chain_start({"name": group_name}, {})
|
||||
yield run_manager.get_child()
|
||||
run_manager.on_chain_end({})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def atrace_as_chain_group(
|
||||
group_name: str,
|
||||
*,
|
||||
session_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
session_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> AsyncGenerator[AsyncCallbackManager, None]:
|
||||
"""Get a callback manager for a chain group in a context manager."""
|
||||
cb = LangChainTracer(
|
||||
tenant_id=tenant_id,
|
||||
session_name=session_name,
|
||||
example_id=example_id,
|
||||
session_extra=session_extra,
|
||||
)
|
||||
cm = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=[cb],
|
||||
)
|
||||
|
||||
run_manager = await cm.on_chain_start({"name": group_name}, {})
|
||||
try:
|
||||
yield run_manager.get_child()
|
||||
finally:
|
||||
await run_manager.on_chain_end({})
|
||||
|
||||
|
||||
def _handle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
|
@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
@ -68,7 +68,7 @@ class LangChainTracer(BaseTracer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
example_id: Optional[UUID] = None,
|
||||
example_id: Optional[Union[UUID, str]] = None,
|
||||
session_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -77,7 +77,9 @@ class LangChainTracer(BaseTracer):
|
||||
self.session: Optional[TracerSession] = None
|
||||
self._endpoint = get_endpoint()
|
||||
self._headers = get_headers()
|
||||
self.example_id = example_id
|
||||
self.example_id = (
|
||||
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||
)
|
||||
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
||||
# set max_workers to 1 to process tasks in order
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
@ -7,9 +7,15 @@ from aiohttp import ClientSession
|
||||
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
from langchain.callbacks import tracing_enabled
|
||||
from langchain.callbacks.manager import tracing_v2_enabled
|
||||
from langchain.callbacks.manager import (
|
||||
atrace_as_chain_group,
|
||||
trace_as_chain_group,
|
||||
tracing_v2_enabled,
|
||||
)
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
questions = [
|
||||
(
|
||||
@ -152,3 +158,59 @@ def test_tracing_v2_context_manager() -> None:
|
||||
agent.run(questions[0]) # this should be traced
|
||||
|
||||
agent.run(questions[0]) # this should not be traced
|
||||
|
||||
|
||||
def test_trace_as_group() -> None:
|
||||
llm = OpenAI(temperature=0.9)
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["product"],
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
with trace_as_chain_group("my_group") as group_manager:
|
||||
chain.run(product="cars", callbacks=group_manager)
|
||||
chain.run(product="computers", callbacks=group_manager)
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
|
||||
with trace_as_chain_group("my_group_2") as group_manager:
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
|
||||
|
||||
def test_trace_as_group_with_env_set() -> None:
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
llm = OpenAI(temperature=0.9)
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["product"],
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
with trace_as_chain_group("my_group") as group_manager:
|
||||
chain.run(product="cars", callbacks=group_manager)
|
||||
chain.run(product="computers", callbacks=group_manager)
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
|
||||
with trace_as_chain_group("my_group_2") as group_manager:
|
||||
chain.run(product="toys", callbacks=group_manager)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_as_group_async() -> None:
|
||||
llm = OpenAI(temperature=0.9)
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["product"],
|
||||
template="What is a good name for a company that makes {product}?",
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
async with atrace_as_chain_group("my_group") as group_manager:
|
||||
await chain.arun(product="cars", callbacks=group_manager)
|
||||
await chain.arun(product="computers", callbacks=group_manager)
|
||||
await chain.arun(product="toys", callbacks=group_manager)
|
||||
|
||||
async with atrace_as_chain_group("my_group_2") as group_manager:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
chain.arun(product="toys", callbacks=group_manager),
|
||||
chain.arun(product="computers", callbacks=group_manager),
|
||||
chain.arun(product="cars", callbacks=group_manager),
|
||||
]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user