mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Share client in evaluation callback (#7807)
Guarantee the evaluator traces go to same endpoint
This commit is contained in:
parent
1db13e8a85
commit
beb38f4f4d
@ -7,6 +7,7 @@ import os
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
@ -45,6 +46,9 @@ from langchain.schema import (
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langsmith import Client as LangSmithClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
|
||||
@ -144,6 +148,7 @@ def tracing_v2_enabled(
|
||||
*,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
client: Optional[LangSmithClient] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||
|
||||
@ -168,6 +173,7 @@ def tracing_v2_enabled(
|
||||
example_id=example_id,
|
||||
project_name=project_name,
|
||||
tags=tags,
|
||||
client=client,
|
||||
)
|
||||
tracing_v2_callback_var.set(cb)
|
||||
yield
|
||||
|
@ -10,6 +10,7 @@ from langsmith import Client, RunEvaluator
|
||||
|
||||
from langchain.callbacks.manager import tracing_v2_enabled
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.langchain import _get_client
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -77,7 +78,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
self.example_id = (
|
||||
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||
)
|
||||
self.client = client or Client()
|
||||
self.client = client or _get_client()
|
||||
self.evaluators = evaluators
|
||||
self.executor = ThreadPoolExecutor(
|
||||
max_workers=max(max_workers or len(evaluators), 1)
|
||||
@ -102,7 +103,9 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
try:
|
||||
if self.project_name is None:
|
||||
self.client.evaluate_run(run, evaluator)
|
||||
with tracing_v2_enabled(project_name=self.project_name, tags=["eval"]):
|
||||
with tracing_v2_enabled(
|
||||
project_name=self.project_name, tags=["eval"], client=self.client
|
||||
):
|
||||
self.client.evaluate_run(run, evaluator)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
Loading…
Reference in New Issue
Block a user