mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
Share client in evaluation callback (#7807)
Guarantee the evaluator traces go to same endpoint
This commit is contained in:
parent
1db13e8a85
commit
beb38f4f4d
@ -7,6 +7,7 @@ import os
|
|||||||
from contextlib import asynccontextmanager, contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Dict,
|
Dict,
|
||||||
@ -45,6 +46,9 @@ from langchain.schema import (
|
|||||||
)
|
)
|
||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langsmith import Client as LangSmithClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||||
|
|
||||||
@ -144,6 +148,7 @@ def tracing_v2_enabled(
|
|||||||
*,
|
*,
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
client: Optional[LangSmithClient] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""Instruct LangChain to log all runs in context to LangSmith.
|
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||||
|
|
||||||
@ -168,6 +173,7 @@ def tracing_v2_enabled(
|
|||||||
example_id=example_id,
|
example_id=example_id,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
|
client=client,
|
||||||
)
|
)
|
||||||
tracing_v2_callback_var.set(cb)
|
tracing_v2_callback_var.set(cb)
|
||||||
yield
|
yield
|
||||||
|
@ -10,6 +10,7 @@ from langsmith import Client, RunEvaluator
|
|||||||
|
|
||||||
from langchain.callbacks.manager import tracing_v2_enabled
|
from langchain.callbacks.manager import tracing_v2_enabled
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.langchain import _get_client
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -77,7 +78,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
self.example_id = (
|
self.example_id = (
|
||||||
UUID(example_id) if isinstance(example_id, str) else example_id
|
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||||
)
|
)
|
||||||
self.client = client or Client()
|
self.client = client or _get_client()
|
||||||
self.evaluators = evaluators
|
self.evaluators = evaluators
|
||||||
self.executor = ThreadPoolExecutor(
|
self.executor = ThreadPoolExecutor(
|
||||||
max_workers=max(max_workers or len(evaluators), 1)
|
max_workers=max(max_workers or len(evaluators), 1)
|
||||||
@ -102,7 +103,9 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
try:
|
try:
|
||||||
if self.project_name is None:
|
if self.project_name is None:
|
||||||
self.client.evaluate_run(run, evaluator)
|
self.client.evaluate_run(run, evaluator)
|
||||||
with tracing_v2_enabled(project_name=self.project_name, tags=["eval"]):
|
with tracing_v2_enabled(
|
||||||
|
project_name=self.project_name, tags=["eval"], client=self.client
|
||||||
|
):
|
||||||
self.client.evaluate_run(run, evaluator)
|
self.client.evaluate_run(run, evaluator)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
Loading…
Reference in New Issue
Block a user