mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-03 21:54:04 +00:00
Update to RunOnDataset helper functions to accept evaluator callbacks (#6629)
Also improve docstrings and update the tracing datasets notebook to focus on "debug, evaluate, monitor"
This commit is contained in:
parent
7ac9b22886
commit
6ca383ecf6
84
langchain/callbacks/tracers/evaluation.py
Normal file
84
langchain/callbacks/tracers/evaluation.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""A tracer that runs evaluators over completed runs."""
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional, Sequence, Set, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
|
||||
class EvaluatorCallbackHandler(BaseTracer):
|
||||
"""A tracer that runs a run evaluator whenever a run is persisted.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
evaluators : Sequence[RunEvaluator]
|
||||
The run evaluators to apply to all top level runs.
|
||||
max_workers : int, optional
|
||||
The maximum number of worker threads to use for running the evaluators.
|
||||
If not specified, it will default to the number of evaluators.
|
||||
client : LangChainPlusClient, optional
|
||||
The LangChainPlusClient instance to use for evaluating the runs.
|
||||
If not specified, a new instance will be created.
|
||||
example_id : Union[UUID, str], optional
|
||||
The example ID to be associated with the runs.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
example_id : Union[UUID, None]
|
||||
The example ID associated with the runs.
|
||||
client : LangChainPlusClient
|
||||
The LangChainPlusClient instance used for evaluating the runs.
|
||||
evaluators : Sequence[RunEvaluator]
|
||||
The sequence of run evaluators to be executed.
|
||||
executor : ThreadPoolExecutor
|
||||
The thread pool executor used for running the evaluators.
|
||||
futures : Set[Future]
|
||||
The set of futures representing the running evaluators.
|
||||
"""
|
||||
|
||||
name = "evaluator_callback_handler"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
evaluators: Sequence[RunEvaluator],
|
||||
max_workers: Optional[int] = None,
|
||||
client: Optional[LangChainPlusClient] = None,
|
||||
example_id: Optional[Union[UUID, str]] = None,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.example_id = (
|
||||
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||
)
|
||||
self.client = client or LangChainPlusClient()
|
||||
self.evaluators = evaluators
|
||||
self.executor = ThreadPoolExecutor(
|
||||
max_workers=max(max_workers or len(evaluators), 1)
|
||||
)
|
||||
self.futures: Set[Future] = set()
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Run the evaluator on the run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
run : Run
|
||||
The run to be evaluated.
|
||||
|
||||
"""
|
||||
run_ = run.copy()
|
||||
run_.reference_example_id = self.example_id
|
||||
for evaluator in self.evaluators:
|
||||
self.futures.add(
|
||||
self.executor.submit(self.client.evaluate_run, run_, evaluator)
|
||||
)
|
||||
|
||||
def wait_for_futures(self) -> None:
|
||||
"""Wait for all futures to complete."""
|
||||
futures = list(self.futures)
|
||||
wait(futures)
|
||||
for future in futures:
|
||||
self.futures.remove(future)
|
@ -1,20 +1,52 @@
|
||||
"""A tracer that collects all nested runs in a list."""
|
||||
from typing import Any, List
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
|
||||
class RunCollectorCallbackHandler(BaseTracer):
|
||||
"""A tracer that collects all nested runs in a list.
|
||||
"""
|
||||
A tracer that collects all nested runs in a list.
|
||||
|
||||
Useful for inspection and for evaluation."""
|
||||
This tracer is useful for inspection and evaluation purposes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
example_id : Optional[Union[UUID, str]], default=None
|
||||
The ID of the example being traced. It can be either a UUID or a string.
|
||||
"""
|
||||
|
||||
name = "run-collector_callback_handler"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the RunCollectorCallbackHandler.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
example_id : Optional[Union[UUID, str]], default=None
|
||||
The ID of the example being traced. It can be either a UUID or a string.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.example_id = (
|
||||
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||
)
|
||||
self.traced_runs: List[Run] = []
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
self.traced_runs.append(run)
|
||||
"""
|
||||
Persist a run by adding it to the traced_runs list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
run : Run
|
||||
The run to be persisted.
|
||||
"""
|
||||
run_ = run.copy()
|
||||
run_.reference_example_id = self.example_id
|
||||
self.traced_runs.append(run_)
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Utilities for running LLMs/Chains over datasets."""
|
||||
"""Utilities for running language models or Chains over datasets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@ -13,15 +14,18 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchainplus_sdk import LangChainPlusClient
|
||||
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
|
||||
from langchainplus_sdk.schemas import Example
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
@ -41,11 +45,21 @@ MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
|
||||
|
||||
|
||||
class InputFormatError(Exception):
|
||||
"""Raised when input format is invalid."""
|
||||
"""Raised when the input format is invalid."""
|
||||
|
||||
|
||||
def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
|
||||
"""Get prompts from inputs."""
|
||||
"""
|
||||
Get prompts from inputs.
|
||||
|
||||
Args:
|
||||
inputs: The input dictionary.
|
||||
|
||||
Returns:
|
||||
A list of prompts.
|
||||
Raises:
|
||||
InputFormatError: If the input format is invalid.
|
||||
"""
|
||||
if not inputs:
|
||||
raise InputFormatError("Inputs should not be empty.")
|
||||
|
||||
@ -83,7 +97,17 @@ def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
|
||||
|
||||
|
||||
def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
|
||||
"""Get Chat Messages from inputs."""
|
||||
"""
|
||||
Get Chat Messages from inputs.
|
||||
|
||||
Args:
|
||||
inputs: The input dictionary.
|
||||
|
||||
Returns:
|
||||
A list of chat messages.
|
||||
Raises:
|
||||
InputFormatError: If the input format is invalid.
|
||||
"""
|
||||
if not inputs:
|
||||
raise InputFormatError("Inputs should not be empty.")
|
||||
|
||||
@ -112,13 +136,25 @@ def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
|
||||
async def _arun_llm(
|
||||
llm: BaseLanguageModel,
|
||||
inputs: Dict[str, Any],
|
||||
langchain_tracer: Optional[LangChainTracer],
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> Union[LLMResult, ChatResult]:
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = (
|
||||
[langchain_tracer] if langchain_tracer else None
|
||||
)
|
||||
"""
|
||||
Asynchronously run the language model.
|
||||
|
||||
Args:
|
||||
llm: The language model to run.
|
||||
inputs: The input dictionary.
|
||||
tags: Optional tags to add to the run.
|
||||
callbacks: Optional callbacks to use during the run.
|
||||
|
||||
Returns:
|
||||
The LLMResult or ChatResult.
|
||||
Raises:
|
||||
ValueError: If the LLM type is unsupported.
|
||||
InputFormatError: If the input format is invalid.
|
||||
"""
|
||||
if isinstance(llm, BaseLLM):
|
||||
try:
|
||||
llm_prompts = _get_prompts(inputs)
|
||||
@ -152,18 +188,32 @@ async def _arun_llm_or_chain(
|
||||
example: Example,
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
n_repetitions: int,
|
||||
langchain_tracer: Optional[LangChainTracer],
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = None,
|
||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||
"""Run the chain asynchronously."""
|
||||
if langchain_tracer is not None:
|
||||
previous_example_id = langchain_tracer.example_id
|
||||
langchain_tracer.example_id = example.id
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
|
||||
"""
|
||||
Asynchronously run the Chain or language model.
|
||||
|
||||
Args:
|
||||
example: The example to run.
|
||||
llm_or_chain_factory: The Chain or language model constructor to run.
|
||||
n_repetitions: The number of times to run the model on each example.
|
||||
tags: Optional tags to add to the run.
|
||||
callbacks: Optional callbacks to use during the run.
|
||||
|
||||
Returns:
|
||||
A list of outputs.
|
||||
"""
|
||||
if callbacks:
|
||||
previous_example_ids = [
|
||||
getattr(tracer, "example_id", None) for tracer in callbacks
|
||||
]
|
||||
for tracer in callbacks:
|
||||
if hasattr(tracer, "example_id"):
|
||||
tracer.example_id = example.id
|
||||
else:
|
||||
previous_example_id = None
|
||||
callbacks = None
|
||||
previous_example_ids = None
|
||||
outputs = []
|
||||
for _ in range(n_repetitions):
|
||||
try:
|
||||
@ -171,8 +221,8 @@ async def _arun_llm_or_chain(
|
||||
output: Any = await _arun_llm(
|
||||
llm_or_chain_factory,
|
||||
example.inputs,
|
||||
langchain_tracer,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
else:
|
||||
chain = llm_or_chain_factory()
|
||||
@ -183,15 +233,19 @@ async def _arun_llm_or_chain(
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||
outputs.append({"Error": str(e)})
|
||||
if langchain_tracer is not None:
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
if callbacks and previous_example_ids:
|
||||
for example_id, tracer in zip(previous_example_ids, callbacks):
|
||||
if hasattr(tracer, "example_id"):
|
||||
tracer.example_id = example_id
|
||||
return outputs
|
||||
|
||||
|
||||
async def _gather_with_concurrency(
|
||||
n: int,
|
||||
initializer: Callable[[], Coroutine[Any, Any, Optional[LangChainTracer]]],
|
||||
*async_funcs: Callable[[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]],
|
||||
initializer: Callable[[], Coroutine[Any, Any, Any]],
|
||||
*async_funcs: Callable[
|
||||
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
|
||||
],
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Run coroutines with a concurrency limit.
|
||||
@ -207,37 +261,42 @@ async def _gather_with_concurrency(
|
||||
semaphore = asyncio.Semaphore(n)
|
||||
job_state = {"num_processed": 0}
|
||||
|
||||
tracer_queue: asyncio.Queue[Optional[LangChainTracer]] = asyncio.Queue()
|
||||
callback_queue: asyncio.Queue[Sequence[BaseCallbackHandler]] = asyncio.Queue()
|
||||
for _ in range(n):
|
||||
tracer_queue.put_nowait(await initializer())
|
||||
callback_queue.put_nowait(await initializer())
|
||||
|
||||
async def run_coroutine_with_semaphore(
|
||||
async_func: Callable[
|
||||
[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]
|
||||
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
|
||||
]
|
||||
) -> Any:
|
||||
async with semaphore:
|
||||
tracer = await tracer_queue.get()
|
||||
callbacks = await callback_queue.get()
|
||||
try:
|
||||
result = await async_func(tracer, job_state)
|
||||
result = await async_func(callbacks, job_state)
|
||||
finally:
|
||||
tracer_queue.put_nowait(tracer)
|
||||
callback_queue.put_nowait(callbacks)
|
||||
return result
|
||||
|
||||
results = await asyncio.gather(
|
||||
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
||||
)
|
||||
while tracer_queue:
|
||||
while callback_queue:
|
||||
try:
|
||||
tracer = tracer_queue.get_nowait()
|
||||
callbacks = callback_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if tracer:
|
||||
tracer.wait_for_futures()
|
||||
for callback in callbacks:
|
||||
if isinstance(callback, (LangChainTracer, EvaluatorCallbackHandler)):
|
||||
callback.wait_for_futures()
|
||||
return results
|
||||
|
||||
|
||||
async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChainTracer]:
|
||||
async def _callbacks_initializer(
|
||||
project_name: Optional[str],
|
||||
client: LangChainPlusClient,
|
||||
run_evaluators: Sequence[RunEvaluator],
|
||||
) -> List[BaseTracer]:
|
||||
"""
|
||||
Initialize a tracer to share across tasks.
|
||||
|
||||
@ -247,11 +306,19 @@ async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChain
|
||||
Returns:
|
||||
A LangChainTracer instance with an active project.
|
||||
"""
|
||||
callbacks: List[BaseTracer] = []
|
||||
if project_name:
|
||||
tracer = LangChainTracer(project_name=project_name)
|
||||
return tracer
|
||||
else:
|
||||
return None
|
||||
callbacks.append(LangChainTracer(project_name=project_name))
|
||||
if run_evaluators:
|
||||
callbacks.append(
|
||||
EvaluatorCallbackHandler(
|
||||
client=client,
|
||||
evaluators=run_evaluators,
|
||||
# We already have concurrency, don't want to overload the machine
|
||||
max_workers=1,
|
||||
)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
|
||||
async def arun_on_examples(
|
||||
@ -262,13 +329,16 @@ async def arun_on_examples(
|
||||
num_repetitions: int = 1,
|
||||
project_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
client: Optional[LangChainPlusClient] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run the chain on examples and store traces to the specified project name.
|
||||
Asynchronously run the chain on examples and store traces
|
||||
to the specified project name.
|
||||
|
||||
Args:
|
||||
examples: Examples to run the model or chain over
|
||||
examples: Examples to run the model or chain over.
|
||||
llm_or_chain_factory: Language model or Chain constructor to run
|
||||
over the dataset. The Chain constructor is used to permit
|
||||
independent calls on each example without carrying over state.
|
||||
@ -277,24 +347,35 @@ async def arun_on_examples(
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
project_name: Project name to use when tracing runs.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
tags: Tags to add to the traces.
|
||||
client: Client to use to read the dataset. If not provided, a new
|
||||
client will be created using the credentials in the environment.
|
||||
tags: Tags to add to each run in the project.
|
||||
run_evaluators: Evaluators to run on the results of the chain.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping example ids to the model outputs.
|
||||
"""
|
||||
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
|
||||
client_ = client or LangChainPlusClient()
|
||||
client_.create_project(project_name, mode="eval")
|
||||
|
||||
results: Dict[str, List[Any]] = {}
|
||||
evaluation_handler = EvaluatorCallbackHandler(
|
||||
evaluators=run_evaluators or [], client=client_
|
||||
)
|
||||
|
||||
async def process_example(
|
||||
example: Example, tracer: Optional[LangChainTracer], job_state: dict
|
||||
example: Example, callbacks: List[BaseCallbackHandler], job_state: dict
|
||||
) -> None:
|
||||
"""Process a single example."""
|
||||
result = await _arun_llm_or_chain(
|
||||
example,
|
||||
llm_or_chain_factory,
|
||||
num_repetitions,
|
||||
tracer,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
results[str(example.id)] = result
|
||||
job_state["num_processed"] += 1
|
||||
@ -307,9 +388,15 @@ async def arun_on_examples(
|
||||
|
||||
await _gather_with_concurrency(
|
||||
concurrency_level,
|
||||
functools.partial(_tracer_initializer, project_name),
|
||||
functools.partial(
|
||||
_callbacks_initializer,
|
||||
project_name=project_name,
|
||||
client=client_,
|
||||
run_evaluators=run_evaluators or [],
|
||||
),
|
||||
*(functools.partial(process_example, e) for e in examples),
|
||||
)
|
||||
evaluation_handler.wait_for_futures()
|
||||
return results
|
||||
|
||||
|
||||
@ -320,7 +407,21 @@ def run_llm(
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Union[LLMResult, ChatResult]:
|
||||
"""Run the language model on the example."""
|
||||
"""
|
||||
Run the language model on the example.
|
||||
|
||||
Args:
|
||||
llm: The language model to run.
|
||||
inputs: The input dictionary.
|
||||
callbacks: The callbacks to use during the run.
|
||||
tags: Optional tags to add to the run.
|
||||
|
||||
Returns:
|
||||
The LLMResult or ChatResult.
|
||||
Raises:
|
||||
ValueError: If the LLM type is unsupported.
|
||||
InputFormatError: If the input format is invalid.
|
||||
"""
|
||||
if isinstance(llm, BaseLLM):
|
||||
try:
|
||||
llm_prompts = _get_prompts(inputs)
|
||||
@ -350,18 +451,32 @@ def run_llm_or_chain(
|
||||
example: Example,
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
n_repetitions: int,
|
||||
langchain_tracer: Optional[LangChainTracer] = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = None,
|
||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||
"""Run the chain synchronously."""
|
||||
if langchain_tracer is not None:
|
||||
previous_example_id = langchain_tracer.example_id
|
||||
langchain_tracer.example_id = example.id
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
|
||||
"""
|
||||
Run the Chain or language model synchronously.
|
||||
|
||||
Args:
|
||||
example: The example to run.
|
||||
llm_or_chain_factory: The Chain or language model constructor to run.
|
||||
n_repetitions: The number of times to run the model on each example.
|
||||
tags: Optional tags to add to the run.
|
||||
callbacks: Optional callbacks to use during the run.
|
||||
|
||||
Returns:
|
||||
A list of outputs.
|
||||
"""
|
||||
if callbacks:
|
||||
previous_example_ids = [
|
||||
getattr(tracer, "example_id", None) for tracer in callbacks
|
||||
]
|
||||
for tracer in callbacks:
|
||||
if hasattr(tracer, "example_id"):
|
||||
tracer.example_id = example.id
|
||||
else:
|
||||
previous_example_id = None
|
||||
callbacks = None
|
||||
previous_example_ids = None
|
||||
outputs = []
|
||||
for _ in range(n_repetitions):
|
||||
try:
|
||||
@ -376,8 +491,10 @@ def run_llm_or_chain(
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||
outputs.append({"Error": str(e)})
|
||||
if langchain_tracer is not None:
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
if callbacks and previous_example_ids:
|
||||
for example_id, tracer in zip(previous_example_ids, callbacks):
|
||||
if hasattr(tracer, "example_id"):
|
||||
tracer.example_id = example_id
|
||||
return outputs
|
||||
|
||||
|
||||
@ -388,48 +505,74 @@ def run_on_examples(
|
||||
num_repetitions: int = 1,
|
||||
project_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
client: Optional[LangChainPlusClient] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain on examples and store traces to the specified project name.
|
||||
"""
|
||||
Run the Chain or language model on examples and store
|
||||
traces to the specified project name.
|
||||
|
||||
Args:
|
||||
examples: Examples to run model or chain over.
|
||||
examples: Examples to run the model or chain over.
|
||||
llm_or_chain_factory: Language model or Chain constructor to run
|
||||
over the dataset. The Chain constructor is used to permit
|
||||
independent calls on each example without carrying over state.
|
||||
concurrency_level: Number of async workers to run in parallel.
|
||||
num_repetitions: Number of times to run the model on each example.
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
project_name: Project name to use when tracing runs.
|
||||
project_name: Name of the project to store the traces in.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
tags: Tags to add to the run traces.
|
||||
client: Client to use to access the dataset. If None, a new client
|
||||
will be created using the credentials in the environment.
|
||||
tags: Tags to add to each run in the project.
|
||||
run_evaluators: Evaluators to run on the results of the chain.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping example ids to the model outputs.
|
||||
"""
|
||||
results: Dict[str, Any] = {}
|
||||
tracer = LangChainTracer(project_name=project_name) if project_name else None
|
||||
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
|
||||
client_ = client or LangChainPlusClient()
|
||||
client_.create_project(project_name, mode="eval")
|
||||
tracer = LangChainTracer(project_name=project_name)
|
||||
evalution_handler = EvaluatorCallbackHandler(
|
||||
evaluators=run_evaluators or [], client=client_
|
||||
)
|
||||
callbacks: List[BaseCallbackHandler] = [tracer, evalution_handler]
|
||||
for i, example in enumerate(examples):
|
||||
result = run_llm_or_chain(
|
||||
example,
|
||||
llm_or_chain_factory,
|
||||
num_repetitions,
|
||||
langchain_tracer=tracer,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if verbose:
|
||||
print(f"{i+1} processed", flush=True, end="\r")
|
||||
results[str(example.id)] = result
|
||||
if tracer:
|
||||
tracer.wait_for_futures()
|
||||
tracer.wait_for_futures()
|
||||
evalution_handler.wait_for_futures()
|
||||
return results
|
||||
|
||||
|
||||
def _get_project_name(
|
||||
project_name: Optional[str],
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
dataset_name: str,
|
||||
dataset_name: Optional[str],
|
||||
) -> str:
|
||||
"""
|
||||
Get the project name.
|
||||
|
||||
Args:
|
||||
project_name: The project name if manually specified.
|
||||
llm_or_chain_factory: The Chain or language model constructor.
|
||||
dataset_name: The dataset name.
|
||||
|
||||
Returns:
|
||||
The project name.
|
||||
"""
|
||||
if project_name is not None:
|
||||
return project_name
|
||||
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
@ -437,7 +580,8 @@ def _get_project_name(
|
||||
model_name = llm_or_chain_factory.__class__.__name__
|
||||
else:
|
||||
model_name = llm_or_chain_factory().__class__.__name__
|
||||
return f"{dataset_name}-{model_name}-{current_time}"
|
||||
dataset_prefix = f"{dataset_name}-" if dataset_name else ""
|
||||
return f"{dataset_prefix}{model_name}-{current_time}"
|
||||
|
||||
|
||||
async def arun_on_dataset(
|
||||
@ -450,12 +594,13 @@ async def arun_on_dataset(
|
||||
verbose: bool = False,
|
||||
client: Optional[LangChainPlusClient] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run the chain on a dataset and store traces to the specified project name.
|
||||
Asynchronously run the Chain or language model on a dataset
|
||||
and store traces to the specified project name.
|
||||
|
||||
Args:
|
||||
client: Client to use to read the dataset.
|
||||
dataset_name: Name of the dataset to run the chain on.
|
||||
llm_or_chain_factory: Language model or Chain constructor to run
|
||||
over the dataset. The Chain constructor is used to permit
|
||||
@ -469,7 +614,8 @@ async def arun_on_dataset(
|
||||
verbose: Whether to print progress.
|
||||
client: Client to use to read the dataset. If not provided, a new
|
||||
client will be created using the credentials in the environment.
|
||||
tags: Tags to add to each run in the sesssion.
|
||||
tags: Tags to add to each run in the project.
|
||||
run_evaluators: Evaluators to run on the results of the chain.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the run's project name and the resulting model outputs.
|
||||
@ -478,7 +624,6 @@ async def arun_on_dataset(
|
||||
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
|
||||
dataset = client_.read_dataset(dataset_name=dataset_name)
|
||||
examples = client_.list_examples(dataset_id=str(dataset.id))
|
||||
|
||||
results = await arun_on_examples(
|
||||
examples,
|
||||
llm_or_chain_factory,
|
||||
@ -486,7 +631,9 @@ async def arun_on_dataset(
|
||||
num_repetitions=num_repetitions,
|
||||
project_name=project_name,
|
||||
verbose=verbose,
|
||||
client=client_,
|
||||
tags=tags,
|
||||
run_evaluators=run_evaluators,
|
||||
)
|
||||
return {
|
||||
"project_name": project_name,
|
||||
@ -503,8 +650,11 @@ def run_on_dataset(
|
||||
verbose: bool = False,
|
||||
client: Optional[LangChainPlusClient] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain on a dataset and store traces to the specified project name.
|
||||
"""
|
||||
Run the Chain or language model on a dataset and store traces
|
||||
to the specified project name.
|
||||
|
||||
Args:
|
||||
dataset_name: Name of the dataset to run the chain on.
|
||||
@ -520,7 +670,8 @@ def run_on_dataset(
|
||||
verbose: Whether to print progress.
|
||||
client: Client to use to access the dataset. If None, a new client
|
||||
will be created using the credentials in the environment.
|
||||
tags: Tags to add to each run in the sesssion.
|
||||
tags: Tags to add to each run in the project.
|
||||
run_evaluators: Evaluators to run on the results of the chain.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the run's project name and the resulting model outputs.
|
||||
@ -536,6 +687,8 @@ def run_on_dataset(
|
||||
project_name=project_name,
|
||||
verbose=verbose,
|
||||
tags=tags,
|
||||
run_evaluators=run_evaluators,
|
||||
client=client_,
|
||||
)
|
||||
return {
|
||||
"project_name": project_name,
|
||||
|
@ -117,10 +117,12 @@ def get_qa_evaluator(
|
||||
choices_map={"CORRECT": 1, "INCORRECT": 0},
|
||||
),
|
||||
)
|
||||
tags = kwargs.pop("tags", [])
|
||||
return RunEvaluatorChain(
|
||||
eval_chain=eval_chain,
|
||||
input_mapper=input_mapper,
|
||||
output_parser=output_parser,
|
||||
tags=tags + [evaluation_name],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -174,6 +176,7 @@ def get_criteria_evaluator(
|
||||
choices_map={"Y": 1, "N": 0}, evaluation_name=evaluation_name
|
||||
),
|
||||
)
|
||||
tags = kwargs.pop("tags", [])
|
||||
eval_chain = CriteriaEvalChain.from_llm(
|
||||
llm=llm, criteria=criteria_, prompt=prompt, **kwargs
|
||||
)
|
||||
@ -181,6 +184,7 @@ def get_criteria_evaluator(
|
||||
eval_chain=eval_chain,
|
||||
input_mapper=input_mapper,
|
||||
output_parser=parser,
|
||||
tags=tags + [evaluation_name],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -303,9 +307,11 @@ def get_trajectory_evaluator(
|
||||
TrajectoryEvalOutputParser(evaluation_name=evaluation_name),
|
||||
)
|
||||
eval_chain = LLMChain(llm=llm, prompt=prompt, **kwargs)
|
||||
tags = kwargs.pop("tags", [])
|
||||
return RunEvaluatorChain(
|
||||
eval_chain=eval_chain,
|
||||
input_mapper=input_mapper,
|
||||
output_parser=parser,
|
||||
tags=tags + [evaluation_name],
|
||||
**kwargs,
|
||||
)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -169,8 +169,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
example: Example,
|
||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||
n_repetitions: int,
|
||||
tracer: Any,
|
||||
tags: Optional[List[str]] = None,
|
||||
callbacks: Optional[Any] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
|
Loading…
Reference in New Issue
Block a user