From beb38f4f4d7ee9aa4e89b278957f7ccd4c072381 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Sun, 16 Jul 2023 17:47:38 -0700 Subject: [PATCH] Share client in evaluation callback (#7807) Guarantee the evaluator traces go to same endpoint --- langchain/callbacks/manager.py | 6 ++++++ langchain/callbacks/tracers/evaluation.py | 7 +++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index b40afcaaef6..a516a05d15d 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -7,6 +7,7 @@ import os from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, Dict, @@ -45,6 +46,9 @@ from langchain.schema import ( ) from langchain.schema.messages import BaseMessage, get_buffer_string +if TYPE_CHECKING: + from langsmith import Client as LangSmithClient + logger = logging.getLogger(__name__) Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] @@ -144,6 +148,7 @@ def tracing_v2_enabled( *, example_id: Optional[Union[str, UUID]] = None, tags: Optional[List[str]] = None, + client: Optional[LangSmithClient] = None, ) -> Generator[None, None, None]: """Instruct LangChain to log all runs in context to LangSmith. @@ -168,6 +173,7 @@ def tracing_v2_enabled( example_id=example_id, project_name=project_name, tags=tags, + client=client, ) tracing_v2_callback_var.set(cb) yield diff --git a/langchain/callbacks/tracers/evaluation.py b/langchain/callbacks/tracers/evaluation.py index d950d112d9e..a4a39063bfc 100644 --- a/langchain/callbacks/tracers/evaluation.py +++ b/langchain/callbacks/tracers/evaluation.py @@ -10,6 +10,7 @@ from langsmith import Client, RunEvaluator from langchain.callbacks.manager import tracing_v2_enabled from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.langchain import _get_client from langchain.callbacks.tracers.schemas import Run logger = logging.getLogger(__name__) @@ -77,7 +78,7 @@ class EvaluatorCallbackHandler(BaseTracer): self.example_id = ( UUID(example_id) if isinstance(example_id, str) else example_id ) - self.client = client or Client() + self.client = client or _get_client() self.evaluators = evaluators self.executor = ThreadPoolExecutor( max_workers=max(max_workers or len(evaluators), 1) @@ -102,7 +103,9 @@ class EvaluatorCallbackHandler(BaseTracer): try: if self.project_name is None: self.client.evaluate_run(run, evaluator) - with tracing_v2_enabled(project_name=self.project_name, tags=["eval"]): + with tracing_v2_enabled( + project_name=self.project_name, tags=["eval"], client=self.client + ): self.client.evaluate_run(run, evaluator) except Exception as e: logger.error(