Compare commits

...

2 Commits

Author SHA1 Message Date
William Fu-Hinthorn
9ae1fb0be1 Skip evaluation of unfinished runs 2023-07-05 14:56:05 -07:00
William Fu-Hinthorn
e3c47d0817 Add tags support for langchaintracer 2023-07-05 12:10:52 -07:00
3 changed files with 26 additions and 2 deletions

View File

@@ -144,6 +144,7 @@ def tracing_v2_enabled(
project_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> Generator[None, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith.
@@ -152,6 +153,8 @@ def tracing_v2_enabled(
Defaults to "default".
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The tags to add to the run.
Defaults to None.
Returns:
None
@@ -170,6 +173,7 @@ def tracing_v2_enabled(
cb = LangChainTracer(
example_id=example_id,
project_name=project_name,
tags=tags,
)
tracing_v2_callback_var.set(cb)
yield

View File

@@ -40,6 +40,9 @@ class EvaluatorCallbackHandler(BaseTracer):
The thread pool executor used for running the evaluators.
futures : Set[Future]
The set of futures representing the running evaluators.
skip_unfinished : bool
Whether to skip runs that are not finished or raised
an error.
"""
name = "evaluator_callback_handler"
@@ -50,6 +53,7 @@ class EvaluatorCallbackHandler(BaseTracer):
max_workers: Optional[int] = None,
client: Optional[LangChainPlusClient] = None,
example_id: Optional[Union[UUID, str]] = None,
skip_unfinished: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -62,6 +66,7 @@ class EvaluatorCallbackHandler(BaseTracer):
max_workers=max(max_workers or len(evaluators), 1)
)
self.futures: Set[Future] = set()
self.skip_unfinished = skip_unfinished
def _evaluate_run(self, run: Run, evaluator: RunEvaluator) -> None:
try:
@@ -70,7 +75,6 @@ class EvaluatorCallbackHandler(BaseTracer):
logger.error(
f"Error evaluating run {run.id} with "
f"{evaluator.__class__.__name__}: {e}",
exc_info=True,
)
raise e
@@ -83,6 +87,8 @@ class EvaluatorCallbackHandler(BaseTracer):
The run to be evaluated.
"""
if self.skip_unfinished and not run.outputs:
return
run_ = run.copy()
run_.reference_example_id = self.example_id
for evaluator in self.evaluators:

View File

@@ -44,6 +44,7 @@ class LangChainTracer(BaseTracer):
example_id: Optional[Union[UUID, str]] = None,
project_name: Optional[str] = None,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer."""
@@ -59,6 +60,7 @@ class LangChainTracer(BaseTracer):
self.executor = ThreadPoolExecutor(max_workers=1)
self.client = client or LangChainPlusClient()
self._futures: Set[Future] = set()
self.tags = tags or []
global _TRACERS
_TRACERS.append(self)
@@ -98,11 +100,21 @@ class LangChainTracer(BaseTracer):
def _persist_run(self, run: Run) -> None:
"""The Langchain Tracer uses Post/Patch rather than persist."""
def _get_tags(self, run: Run) -> List[str]:
"""Add tags to a run."""
tags = run.tags or []
if self.tags:
for tag in self.tags:
if tag not in tags:
tags.append(tag)
return tags
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
run_dict = run.dict(exclude={"child_runs"})
run_dict["tags"] = self._get_tags(run)
extra = run_dict.get("extra", {})
extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra
@@ -116,7 +128,9 @@ class LangChainTracer(BaseTracer):
def _update_run_single(self, run: Run) -> None:
"""Update a run."""
try:
self.client.update_run(run.id, **run.dict())
run_dict = run.dict()
run_dict["tags"] = self._get_tags(run)
self.client.update_run(run.id, **run_dict)
except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here
log_error_once("patch", e)