Compare commits

...

2 Commits

Author                SHA1         Message   Date
William Fu-Hinthorn   d180208915   nhp       2023-09-08 14:07:49 -07:00
William Fu-Hinthorn   625e598111   .         2023-09-08 10:42:55 -07:00

@@ -278,35 +278,21 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
 ## Shared data validation utilities
 def _validate_example_inputs_for_language_model(
     first_example: Example,
-    input_mapper: Optional[Callable[[Dict], Any]],
 ) -> None:
-    if input_mapper:
-        prompt_input = input_mapper(first_example.inputs)
-        if not isinstance(prompt_input, str) and not (
-            isinstance(prompt_input, list)
-            and all(isinstance(msg, BaseMessage) for msg in prompt_input)
-        ):
-            raise InputFormatError(
-                "When using an input_mapper to prepare dataset example inputs"
-                " for an LLM or chat model, the output must a single string or"
-                " a list of chat messages."
-                f"\nGot: {prompt_input} of type {type(prompt_input)}."
-            )
-    else:
-        try:
-            _get_prompt(first_example.inputs)
-        except InputFormatError:
-            try:
-                _get_messages(first_example.inputs)
-            except InputFormatError:
-                raise InputFormatError(
-                    "Example inputs do not match language model input format. "
-                    "Expected a dictionary with messages or a single prompt."
-                    f" Got: {first_example.inputs}"
-                    " Please update your dataset OR provide an input_mapper"
-                    " to convert the example.inputs to a compatible format"
-                    " for the llm or chat model you wish to evaluate."
-                )
+    try:
+        _get_prompt(first_example.inputs)
+    except InputFormatError:
+        try:
+            _get_messages(first_example.inputs)
+        except InputFormatError:
+            raise InputFormatError(
+                "Example inputs do not match language model input format. "
+                "Expected a dictionary with messages or a single prompt."
+                f" Got: {first_example.inputs}"
+                " Please update your dataset OR provide an input_mapper"
+                " to convert the example.inputs to a compatible format"
+                " for the llm or chat model you wish to evaluate."
+            )


 def _validate_example_inputs_for_chain(
@@ -349,16 +335,15 @@ def _validate_example_inputs_for_chain(
 def _validate_example_inputs(
     example: Example,
     llm_or_chain_factory: MCF,
-    input_mapper: Optional[Callable[[Dict], Any]],
 ) -> None:
     """Validate that the example inputs are valid for the model."""
     if isinstance(llm_or_chain_factory, BaseLanguageModel):
-        _validate_example_inputs_for_language_model(example, input_mapper)
+        _validate_example_inputs_for_language_model(example)
     else:
         chain = llm_or_chain_factory()
         if isinstance(chain, Chain):
             # Otherwise it's a runnable
-            _validate_example_inputs_for_chain(example, chain, input_mapper)
+            _validate_example_inputs_for_chain(example, chain)
         elif isinstance(chain, Runnable):
             logger.debug(f"Skipping input validation for {chain}")
@@ -591,7 +576,6 @@ async def _arun_llm(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[str, BaseMessage]:
     """Asynchronously run the language model.
@@ -600,7 +584,6 @@ async def _arun_llm(
         inputs: The input dictionary.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
-        input_mapper: Optional function to map inputs to the expected format.

     Returns:
         The LLMResult or ChatResult.
@@ -608,36 +591,16 @@ async def _arun_llm(
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if input_mapper is not None:
-        prompt_or_messages = input_mapper(inputs)
-        if isinstance(prompt_or_messages, str):
-            return await llm.apredict(
-                prompt_or_messages, callbacks=callbacks, tags=tags
-            )
-        elif isinstance(prompt_or_messages, list) and all(
-            isinstance(msg, BaseMessage) for msg in prompt_or_messages
-        ):
-            return await llm.apredict_messages(
-                prompt_or_messages, callbacks=callbacks, tags=tags
-            )
-        else:
-            raise InputFormatError(
-                "Input mapper returned invalid format"
-                f" {prompt_or_messages}"
-                "\nExpected a single string or list of chat messages."
-            )
-    else:
-        try:
-            prompt = _get_prompt(inputs)
-            llm_output: Union[str, BaseMessage] = await llm.apredict(
-                prompt, callbacks=callbacks, tags=tags
-            )
-        except InputFormatError:
-            messages = _get_messages(inputs)
-            llm_output = await llm.apredict_messages(
-                messages, callbacks=callbacks, tags=tags
-            )
+    try:
+        prompt = _get_prompt(inputs)
+        llm_output: Union[str, BaseMessage] = await llm.apredict(
+            prompt, callbacks=callbacks, tags=tags
+        )
+    except InputFormatError:
+        messages = _get_messages(inputs)
+        llm_output = await llm.apredict_messages(
+            messages, callbacks=callbacks, tags=tags
+        )
     return llm_output
@@ -647,19 +610,17 @@ async def _arun_chain(
     callbacks: Callbacks,
     *,
     tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
-    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
     if isinstance(chain, Chain):
-        if isinstance(inputs_, dict) and len(inputs_) == 1:
-            val = next(iter(inputs_.values()))
+        if isinstance(inputs, dict) and len(inputs) == 1:
+            val = next(iter(inputs.values()))
             output = await chain.acall(val, callbacks=callbacks, tags=tags)
         else:
-            output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
+            output = await chain.acall(inputs, callbacks=callbacks, tags=tags)
     else:
         runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
-        output = await chain.ainvoke(inputs_, config=runnable_config)
+        output = await chain.ainvoke(inputs, config=runnable_config)
     return output
@@ -668,7 +629,6 @@ async def _arun_llm_or_chain(
     config: RunnableConfig,
     *,
     llm_or_chain_factory: MCF,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str, LLMResult, ChatResult]:
     """Asynchronously run the Chain or language model.
@@ -693,7 +653,6 @@ async def _arun_llm_or_chain(
                 example.inputs,
                 tags=config["tags"],
                 callbacks=config["callbacks"],
-                input_mapper=input_mapper,
             )
         else:
             chain = llm_or_chain_factory()
@@ -702,7 +661,6 @@ async def _arun_llm_or_chain(
                 example.inputs,
                 tags=config["tags"],
                 callbacks=config["callbacks"],
-                input_mapper=input_mapper,
             )
         result = output
     except Exception as e:
@@ -775,19 +733,17 @@ def _run_chain(
     callbacks: Callbacks,
     *,
     tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
-    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
     if isinstance(chain, Chain):
-        if isinstance(inputs_, dict) and len(inputs_) == 1:
-            val = next(iter(inputs_.values()))
+        if isinstance(inputs, dict) and len(inputs) == 1:
+            val = next(iter(inputs.values()))
             output = chain(val, callbacks=callbacks, tags=tags)
         else:
-            output = chain(inputs_, callbacks=callbacks, tags=tags)
+            output = chain(inputs, callbacks=callbacks, tags=tags)
     else:
         runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
-        output = chain.invoke(inputs_, config=runnable_config)
+        output = chain.invoke(inputs, config=runnable_config)
     return output
@@ -796,7 +752,6 @@ def _run_llm_or_chain(
     config: RunnableConfig,
     *,
     llm_or_chain_factory: MCF,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str, LLMResult, ChatResult]:
     """
     Run the Chain or language model synchronously.
@@ -822,7 +777,6 @@ def _run_llm_or_chain(
                 example.inputs,
                 config["callbacks"],
                 tags=config["tags"],
-                input_mapper=input_mapper,
             )
         else:
             chain = llm_or_chain_factory()
@@ -831,7 +785,6 @@ def _run_llm_or_chain(
                 example.inputs,
                 config["callbacks"],
                 tags=config["tags"],
-                input_mapper=input_mapper,
             )
         result = output
     except Exception as e:
@@ -879,9 +832,7 @@ def _prepare_run_on_dataset(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
-    concurrency_level: int = 5,
+    config: Optional[RunnableConfig] = None,
 ) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
     project_name = project_name or name_generation.random_name()
     wrapped_model, project_name, dataset, examples = _prepare_eval_run(
@@ -891,8 +842,10 @@ def _prepare_run_on_dataset(
     run_evaluators = _setup_evaluation(
         wrapped_model, examples, evaluation, dataset.data_type
     )
-    _validate_example_inputs(examples[0], wrapped_model, input_mapper)
+    _validate_example_inputs(examples[0], wrapped_model)
     progress_bar = progress.ProgressBarCallback(len(examples))
+    run_config = config or RunnableConfig(max_concurrency=5)
+    _callbacks = run_config["callbacks"] if "callbacks" in run_config else []
     configs = [
         RunnableConfig(
             callbacks=[
@@ -909,9 +862,9 @@ def _prepare_run_on_dataset(
                     example_id=example.id,
                 ),
                 progress_bar,
-            ],
-            tags=tags or [],
-            max_concurrency=concurrency_level,
+            ]
+            + _callbacks,
+            **run_config,
         )
         for example in examples
     ]
@@ -947,16 +900,14 @@ def _collect_test_results(


 async def arun_on_dataset(
-    client: Client,
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
+    client: Optional[Client] = None,
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
-    concurrency_level: int = 5,
+    config: Optional[RunnableConfig] = None,
     project_name: Optional[str] = None,
     verbose: bool = False,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     """
@@ -964,24 +915,22 @@ async def arun_on_dataset(
     and store traces to the specified project name.

     Args:
+        dataset_name: Name of the dataset to run the chain on.
+        llm_or_chain_factory: Either a runnable or a constructor for a runnable.
+            If your runnable contains state (e.g. an agent or
+            chain with memory), you should pass in a constructor
+            function that returns a new runnable instance for each
+            example row in the dataset.
         client: LangSmith client to use to read the dataset, and to
             log feedback and run traces.
-        dataset_name: Name of the dataset to run the chain on.
-        llm_or_chain_factory: Language model or Chain constructor to run
-            over the dataset. The Chain constructor is used to permit
-            independent calls on each example without carrying over state.
-        evaluation: Optional evaluation configuration to use when evaluating
-        concurrency_level: The number of async tasks to run concurrently.
+        evaluation: Optional evaluation configuration to specify which
+            evaluators to run on the results of the chain and which
+            key(s) to consider during evaluation.
+        config: Optional configuration to specify concurrency, tags,
+            or other information to apply to the chain runs.
         project_name: Name of the project to store the traces in.
-            Defaults to {dataset_name}-{chain class name}-{datetime}.
+            Defaults to a friendly name.
         verbose: Whether to print progress.
-        tags: Tags to add to each run in the project.
-        input_mapper: A function to map to the inputs dictionary from an Example
-            to the format expected by the model to be evaluated. This is useful if
-            your model needs to deserialize more complex schema or if your dataset
-            has inputs with keys that differ from what is expected by your chain
-            or agent.

     Returns:
         A dictionary containing the run's project name and the
             resulting model outputs.
@@ -993,7 +942,6 @@ async def arun_on_dataset(

     .. code-block:: python

-        from langsmith import Client
         from langchain.chat_models import ChatOpenAI
         from langchain.chains import LLMChain
         from langchain.smith import smith_eval.RunEvalConfig, arun_on_dataset
@@ -1020,11 +968,9 @@ async def arun_on_dataset(
             ]
         )

-        client = Client()
         await arun_on_dataset(
-            client,
-            "<my_dataset_name>",
-            construct_chain,
+            dataset_name="<my_dataset_name>",
+            llm_or_chain_config=construct_chain,
             evaluation=evaluation_config,
         )
@@ -1060,7 +1006,6 @@ async def arun_on_dataset(
         )

         await arun_on_dataset(
-            client,
             "<my_dataset_name>",
             construct_chain,
             evaluation=evaluation_config,
@@ -1073,15 +1018,14 @@ async def arun_on_dataset(
             f"{kwargs.keys()}.",
             DeprecationWarning,
         )
+    client = client or Client()
     wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
         client,
         dataset_name,
         llm_or_chain_factory,
         project_name,
         evaluation,
-        tags,
-        input_mapper,
-        concurrency_level,
+        config=config,
     )
     batch_results = await runnable_utils.gather_with_concurrency(
@@ -1090,7 +1034,6 @@ async def arun_on_dataset(
             functools.partial(
                 _arun_llm_or_chain,
                 llm_or_chain_factory=wrapped_model,
-                input_mapper=input_mapper,
             ),
             examples,
             configs,
@@ -1108,16 +1051,14 @@ def run_on_dataset(


 def run_on_dataset(
-    client: Client,
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
+    client: Optional[Client] = None,
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
-    concurrency_level: int = 5,
+    config: Optional[RunnableConfig] = None,
     project_name: Optional[str] = None,
     verbose: bool = False,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     """
@@ -1135,7 +1076,7 @@ def run_on_dataset(
             results of the chain
         concurrency_level: The number of async tasks to run concurrently.
         project_name: Name of the project to store the traces in.
-            Defaults to {dataset_name}-{chain class name}-{datetime}.
+            Defaults to a friendly name
         verbose: Whether to print progress.
         tags: Tags to add to each run in the project.
         input_mapper: A function to map to the inputs dictionary from an Example
@@ -1155,7 +1096,6 @@ def run_on_dataset(

     .. code-block:: python

-        from langsmith import Client
         from langchain.chat_models import ChatOpenAI
         from langchain.chains import LLMChain
         from langchain.smith import smith_eval.RunEvalConfig, run_on_dataset
@@ -1182,12 +1122,11 @@ def run_on_dataset(
             ]
         )

-        client = Client()
         run_on_dataset(
-            client,
-            "<my_dataset_name>",
-            construct_chain,
+            dataset_name="<my_dataset_name>",
+            llm_or_chain_factory=construct_chain,
             evaluation=evaluation_config,
+            verbose=True,
         )

     You can also create custom evaluators by subclassing the
@@ -1222,7 +1161,6 @@ def run_on_dataset(
         )

         run_on_dataset(
-            client,
             "<my_dataset_name>",
             construct_chain,
             evaluation=evaluation_config,
@@ -1235,15 +1173,14 @@ def run_on_dataset(
             f"{kwargs.keys()}.",
             DeprecationWarning,
         )
+    client = client or Client()
     wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
         client,
         dataset_name,
         llm_or_chain_factory,
         project_name,
         evaluation,
-        tags,
-        input_mapper,
-        concurrency_level,
+        config=config,
     )
     with runnable_config.get_executor_for_config(configs[0]) as executor:
         batch_results = list(
@@ -1251,7 +1188,6 @@ def run_on_dataset(
                 functools.partial(
                     _run_llm_or_chain,
                     llm_or_chain_factory=wrapped_model,
-                    input_mapper=input_mapper,
                 ),
                 examples,
                 configs,
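
Taken together, these hunks drop the positional client argument and the tags, concurrency_level, and input_mapper keywords from run_on_dataset / arun_on_dataset, folding per-run options into a single optional RunnableConfig. A minimal sketch of a call against the new signature, assuming the same placeholder dataset name and chain factory used in the docstring examples (the evaluator choice and tag values below are illustrative, not part of the diff):

    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.schema.runnable import RunnableConfig
    from langchain.smith import RunEvalConfig, run_on_dataset


    def construct_chain():
        # Build a fresh chain per example so no state carries over between rows.
        llm = ChatOpenAI(temperature=0)
        return LLMChain.from_string(llm, "What's the answer to {your_input_key}")


    evaluation_config = RunEvalConfig(evaluators=["qa"])

    # Tags and concurrency now travel inside one RunnableConfig instead of the
    # removed tags= / concurrency_level= keyword arguments; the client keyword
    # is optional and defaults to a fresh Client().
    run_on_dataset(
        dataset_name="<my_dataset_name>",  # placeholder, as in the docstring example
        llm_or_chain_factory=construct_chain,
        evaluation=evaluation_config,
        config=RunnableConfig(tags=["my-eval-run"], max_concurrency=5),
    )

Whatever is passed as config is merged into each per-example RunnableConfig by _prepare_run_on_dataset, which falls back to RunnableConfig(max_concurrency=5) when no config is supplied.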