Share client in evaluation callback (#7807)

Guarantee that the evaluator traces go to the same endpoint.
William FH 2023-07-16 17:47:38 -07:00 committed by GitHub
parent 1db13e8a85
commit beb38f4f4d
2 changed files with 11 additions and 2 deletions


@@ -7,6 +7,7 @@ import os
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncGenerator,
     Dict,
@@ -45,6 +46,9 @@ from langchain.schema import (
 )
 from langchain.schema.messages import BaseMessage, get_buffer_string
 
+if TYPE_CHECKING:
+    from langsmith import Client as LangSmithClient
+
 logger = logging.getLogger(__name__)
 
 Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@@ -144,6 +148,7 @@ def tracing_v2_enabled(
     *,
     example_id: Optional[Union[str, UUID]] = None,
     tags: Optional[List[str]] = None,
+    client: Optional[LangSmithClient] = None,
 ) -> Generator[None, None, None]:
     """Instruct LangChain to log all runs in context to LangSmith.
@@ -168,6 +173,7 @@ def tracing_v2_enabled(
         example_id=example_id,
         project_name=project_name,
         tags=tags,
+        client=client,
     )
     tracing_v2_callback_var.set(cb)
     yield
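
The hunk above adds an optional client argument to the tracing_v2_enabled context manager, so callers can supply an existing langsmith Client instead of letting the tracer build a default one. A minimal usage sketch, assuming the LangSmith environment variables are already configured and using a placeholder project name (neither is part of this diff):

from langsmith import Client

from langchain.callbacks.manager import tracing_v2_enabled

# Placeholder client; endpoint and API key come from the environment.
client = Client()

with tracing_v2_enabled(project_name="my-project", client=client):
    # Runs started inside this block are traced through the client above
    # rather than through a freshly constructed default client.
    ...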


@@ -10,6 +10,7 @@ from langsmith import Client, RunEvaluator
 from langchain.callbacks.manager import tracing_v2_enabled
 from langchain.callbacks.tracers.base import BaseTracer
+from langchain.callbacks.tracers.langchain import _get_client
 from langchain.callbacks.tracers.schemas import Run
 
 logger = logging.getLogger(__name__)
@@ -77,7 +78,7 @@ class EvaluatorCallbackHandler(BaseTracer):
         self.example_id = (
             UUID(example_id) if isinstance(example_id, str) else example_id
         )
-        self.client = client or Client()
+        self.client = client or _get_client()
         self.evaluators = evaluators
         self.executor = ThreadPoolExecutor(
             max_workers=max(max_workers or len(evaluators), 1)
@ -102,7 +103,9 @@ class EvaluatorCallbackHandler(BaseTracer):
try: try:
if self.project_name is None: if self.project_name is None:
self.client.evaluate_run(run, evaluator) self.client.evaluate_run(run, evaluator)
with tracing_v2_enabled(project_name=self.project_name, tags=["eval"]): with tracing_v2_enabled(
project_name=self.project_name, tags=["eval"], client=self.client
):
self.client.evaluate_run(run, evaluator) self.client.evaluate_run(run, evaluator)
except Exception as e: except Exception as e:
logger.error( logger.error(
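
With both files changed, EvaluatorCallbackHandler reuses its own client (explicitly passed, or obtained from _get_client()) when it opens the tracing context, so the feedback it writes and the traces of the evaluator runs themselves go through the same endpoint. A sketch of passing an explicit client, assuming the import path for EvaluatorCallbackHandler and using placeholder evaluator and project names (not from this diff):

from langsmith import Client

from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler

client = Client()  # shared client; endpoint and API key come from the environment
my_evaluator = ...  # placeholder for any langsmith RunEvaluator implementation

handler = EvaluatorCallbackHandler(
    evaluators=[my_evaluator],
    client=client,
    project_name="my-eval-project",  # placeholder project name
)
# When the handler evaluates a run, tracing_v2_enabled now receives
# handler.client, so the evaluation traces land on the same endpoint as the
# feedback written by client.evaluate_run(...).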