diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index b40afcaaef6..a516a05d15d 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -7,6 +7,7 @@ import os from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, Dict, @@ -45,6 +46,9 @@ from langchain.schema import ( ) from langchain.schema.messages import BaseMessage, get_buffer_string +if TYPE_CHECKING: + from langsmith import Client as LangSmithClient + logger = logging.getLogger(__name__) Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] @@ -144,6 +148,7 @@ def tracing_v2_enabled( *, example_id: Optional[Union[str, UUID]] = None, tags: Optional[List[str]] = None, + client: Optional[LangSmithClient] = None, ) -> Generator[None, None, None]: """Instruct LangChain to log all runs in context to LangSmith. @@ -168,6 +173,7 @@ def tracing_v2_enabled( example_id=example_id, project_name=project_name, tags=tags, + client=client, ) tracing_v2_callback_var.set(cb) yield diff --git a/langchain/callbacks/tracers/evaluation.py b/langchain/callbacks/tracers/evaluation.py index d950d112d9e..a4a39063bfc 100644 --- a/langchain/callbacks/tracers/evaluation.py +++ b/langchain/callbacks/tracers/evaluation.py @@ -10,6 +10,7 @@ from langsmith import Client, RunEvaluator from langchain.callbacks.manager import tracing_v2_enabled from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.langchain import _get_client from langchain.callbacks.tracers.schemas import Run logger = logging.getLogger(__name__) @@ -77,7 +78,7 @@ class EvaluatorCallbackHandler(BaseTracer): self.example_id = ( UUID(example_id) if isinstance(example_id, str) else example_id ) - self.client = client or Client() + self.client = client or _get_client() self.evaluators = evaluators self.executor = ThreadPoolExecutor( max_workers=max(max_workers or len(evaluators), 1) @@ -102,7 +103,9 @@ class EvaluatorCallbackHandler(BaseTracer): try: if self.project_name is None: self.client.evaluate_run(run, evaluator) - with tracing_v2_enabled(project_name=self.project_name, tags=["eval"]): + with tracing_v2_enabled( + project_name=self.project_name, tags=["eval"], client=self.client + ): self.client.evaluate_run(run, evaluator) except Exception as e: logger.error(