Share client in evaluation callback (#7807)

Guarantee the evaluator traces go to same endpoint
This commit is contained in:
William FH 2023-07-16 17:47:38 -07:00 committed by GitHub
parent 1db13e8a85
commit beb38f4f4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -7,6 +7,7 @@ import os
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
@ -45,6 +46,9 @@ from langchain.schema import (
)
from langchain.schema.messages import BaseMessage, get_buffer_string
if TYPE_CHECKING:
from langsmith import Client as LangSmithClient
logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@ -144,6 +148,7 @@ def tracing_v2_enabled(
*,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
client: Optional[LangSmithClient] = None,
) -> Generator[None, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith.
@ -168,6 +173,7 @@ def tracing_v2_enabled(
example_id=example_id,
project_name=project_name,
tags=tags,
client=client,
)
tracing_v2_callback_var.set(cb)
yield

View File

@ -10,6 +10,7 @@ from langsmith import Client, RunEvaluator
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.langchain import _get_client
from langchain.callbacks.tracers.schemas import Run
logger = logging.getLogger(__name__)
@ -77,7 +78,7 @@ class EvaluatorCallbackHandler(BaseTracer):
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.client = client or Client()
self.client = client or _get_client()
self.evaluators = evaluators
self.executor = ThreadPoolExecutor(
max_workers=max(max_workers or len(evaluators), 1)
@ -102,7 +103,9 @@ class EvaluatorCallbackHandler(BaseTracer):
try:
if self.project_name is None:
self.client.evaluate_run(run, evaluator)
with tracing_v2_enabled(project_name=self.project_name, tags=["eval"]):
with tracing_v2_enabled(
project_name=self.project_name, tags=["eval"], client=self.client
):
self.client.evaluate_run(run, evaluator)
except Exception as e:
logger.error(