Compare commits

...

2 Commits

Author               SHA1        Message  Date
William Fu-Hinthorn  d180208915  nhp      2023-09-08 14:07:49 -07:00
William Fu-Hinthorn  625e598111  .        2023-09-08 10:42:55 -07:00


@@ -278,35 +278,21 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
## Shared data validation utilities
def _validate_example_inputs_for_language_model(
first_example: Example,
input_mapper: Optional[Callable[[Dict], Any]],
) -> None:
if input_mapper:
prompt_input = input_mapper(first_example.inputs)
if not isinstance(prompt_input, str) and not (
isinstance(prompt_input, list)
and all(isinstance(msg, BaseMessage) for msg in prompt_input)
):
raise InputFormatError(
"When using an input_mapper to prepare dataset example inputs"
" for an LLM or chat model, the output must a single string or"
" a list of chat messages."
f"\nGot: {prompt_input} of type {type(prompt_input)}."
)
else:
try:
_get_prompt(first_example.inputs)
except InputFormatError:
try:
_get_prompt(first_example.inputs)
_get_messages(first_example.inputs)
except InputFormatError:
try:
_get_messages(first_example.inputs)
except InputFormatError:
raise InputFormatError(
"Example inputs do not match language model input format. "
"Expected a dictionary with messages or a single prompt."
f" Got: {first_example.inputs}"
" Please update your dataset OR provide an input_mapper"
" to convert the example.inputs to a compatible format"
" for the llm or chat model you wish to evaluate."
)
raise InputFormatError(
"Example inputs do not match language model input format. "
"Expected a dictionary with messages or a single prompt."
f" Got: {first_example.inputs}"
" Please update your dataset OR provide an input_mapper"
" to convert the example.inputs to a compatible format"
" for the llm or chat model you wish to evaluate."
)
def _validate_example_inputs_for_chain(
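Read side by side, the interleaved hunk above drops the input_mapper parameter and its branch and keeps only the prompt-or-messages fallback. A sketch of the resulting validator as best it can be reconstructed from the added lines (Example, InputFormatError, _get_prompt and _get_messages are the module's own names; this is an illustration, not the verbatim file contents):

def _validate_example_inputs_for_language_model(
    first_example: Example,
) -> None:
    # Accept either supported shape: a single prompt string or chat messages.
    try:
        _get_prompt(first_example.inputs)
    except InputFormatError:
        try:
            _get_messages(first_example.inputs)
        except InputFormatError:
            raise InputFormatError(
                "Example inputs do not match language model input format. "
                "Expected a dictionary with messages or a single prompt."
                f" Got: {first_example.inputs}"
                " Please update your dataset OR provide an input_mapper"
                " to convert the example.inputs to a compatible format"
                " for the llm or chat model you wish to evaluate."
            )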
@@ -349,16 +335,15 @@ def _validate_example_inputs_for_chain(
def _validate_example_inputs(
example: Example,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]],
) -> None:
"""Validate that the example inputs are valid for the model."""
if isinstance(llm_or_chain_factory, BaseLanguageModel):
_validate_example_inputs_for_language_model(example, input_mapper)
_validate_example_inputs_for_language_model(example)
else:
chain = llm_or_chain_factory()
if isinstance(chain, Chain):
# Otherwise it's a runnable
_validate_example_inputs_for_chain(example, chain, input_mapper)
_validate_example_inputs_for_chain(example, chain)
elif isinstance(chain, Runnable):
logger.debug(f"Skipping input validation for {chain}")
@@ -591,7 +576,6 @@ async def _arun_llm(
*,
tags: Optional[List[str]] = None,
callbacks: Callbacks = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[str, BaseMessage]:
"""Asynchronously run the language model.
@@ -600,7 +584,6 @@ async def _arun_llm(
inputs: The input dictionary.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
input_mapper: Optional function to map inputs to the expected format.
Returns:
The LLMResult or ChatResult.
@@ -608,36 +591,16 @@ async def _arun_llm(
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
if input_mapper is not None:
prompt_or_messages = input_mapper(inputs)
if isinstance(prompt_or_messages, str):
return await llm.apredict(
prompt_or_messages, callbacks=callbacks, tags=tags
)
elif isinstance(prompt_or_messages, list) and all(
isinstance(msg, BaseMessage) for msg in prompt_or_messages
):
return await llm.apredict_messages(
prompt_or_messages, callbacks=callbacks, tags=tags
)
else:
raise InputFormatError(
"Input mapper returned invalid format"
f" {prompt_or_messages}"
"\nExpected a single string or list of chat messages."
)
else:
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.apredict(
prompt, callbacks=callbacks, tags=tags
)
except InputFormatError:
messages = _get_messages(inputs)
llm_output = await llm.apredict_messages(
messages, callbacks=callbacks, tags=tags
)
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.apredict(
prompt, callbacks=callbacks, tags=tags
)
except InputFormatError:
messages = _get_messages(inputs)
llm_output = await llm.apredict_messages(
messages, callbacks=callbacks, tags=tags
)
return llm_output
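For context on the fallback above: a dataset example only has to satisfy one of the two shapes. A rough, hypothetical illustration of inputs for each path (the exact keys and message serialization accepted are defined by _get_prompt and _get_messages, which are not part of this diff):

# Hypothetical example.inputs payloads, for illustration only.
prompt_style = {"prompt": "What is the capital of France?"}
messages_style = {
    "messages": [
        {"type": "human", "data": {"content": "What is the capital of France?"}}
    ]
}
# _get_prompt(prompt_style) yields a plain string, so llm.apredict(...) is used.
# If that raises InputFormatError, the inputs are retried as serialized chat
# messages via _get_messages(...) and llm.apredict_messages(...).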
@@ -647,19 +610,17 @@ async def _arun_chain(
callbacks: Callbacks,
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str]:
"""Run a chain asynchronously on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
if isinstance(inputs, dict) and len(inputs) == 1:
val = next(iter(inputs.values()))
output = await chain.acall(val, callbacks=callbacks, tags=tags)
else:
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
output = await chain.acall(inputs, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = await chain.ainvoke(inputs_, config=runnable_config)
output = await chain.ainvoke(inputs, config=runnable_config)
return output
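One behavior worth calling out in the Chain branch above (and in the synchronous _run_chain further down): a single-key input dict is unwrapped and the bare value is passed to the chain. A minimal standalone sketch of that rule, with a hypothetical helper name:

from typing import Any, Dict, Union


def unwrap_single_key(inputs: Dict[str, Any]) -> Union[Dict[str, Any], Any]:
    """Hypothetical helper mirroring the check in _arun_chain/_run_chain."""
    # A dict with exactly one key is collapsed to its value, so chains that
    # declare a single input variable can be called with the bare value.
    if isinstance(inputs, dict) and len(inputs) == 1:
        return next(iter(inputs.values()))
    return inputs


# {"question": "hi"} -> "hi"
# {"question": "hi", "context": "..."} -> passed through unchanged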
@@ -668,7 +629,6 @@ async def _arun_llm_or_chain(
config: RunnableConfig,
*,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str, LLMResult, ChatResult]:
"""Asynchronously run the Chain or language model.
@@ -693,7 +653,6 @@ async def _arun_llm_or_chain(
example.inputs,
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
)
else:
chain = llm_or_chain_factory()
@@ -702,7 +661,6 @@ async def _arun_llm_or_chain(
example.inputs,
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
)
result = output
except Exception as e:
@@ -775,19 +733,17 @@ def _run_chain(
callbacks: Callbacks,
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[Dict, str]:
"""Run a chain on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
if isinstance(inputs, dict) and len(inputs) == 1:
val = next(iter(inputs.values()))
output = chain(val, callbacks=callbacks, tags=tags)
else:
output = chain(inputs_, callbacks=callbacks, tags=tags)
output = chain(inputs, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = chain.invoke(inputs_, config=runnable_config)
output = chain.invoke(inputs, config=runnable_config)
return output
@@ -796,7 +752,6 @@ def _run_llm_or_chain(
config: RunnableConfig,
*,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str, LLMResult, ChatResult]:
"""
Run the Chain or language model synchronously.
@@ -822,7 +777,6 @@ def _run_llm_or_chain(
example.inputs,
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
)
else:
chain = llm_or_chain_factory()
@@ -831,7 +785,6 @@ def _run_llm_or_chain(
example.inputs,
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
)
result = output
except Exception as e:
@@ -879,9 +832,7 @@ def _prepare_run_on_dataset(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
evaluation: Optional[smith_eval.RunEvalConfig] = None,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
concurrency_level: int = 5,
config: Optional[RunnableConfig] = None,
) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
project_name = project_name or name_generation.random_name()
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
@@ -891,8 +842,10 @@ def _prepare_run_on_dataset(
run_evaluators = _setup_evaluation(
wrapped_model, examples, evaluation, dataset.data_type
)
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
_validate_example_inputs(examples[0], wrapped_model)
progress_bar = progress.ProgressBarCallback(len(examples))
run_config = config or RunnableConfig(max_concurrency=5)
_callbacks = run_config["callbacks"] if "callbacks" in run_config else []
configs = [
RunnableConfig(
callbacks=[
@@ -909,9 +862,9 @@ def _prepare_run_on_dataset(
example_id=example.id,
),
progress_bar,
],
tags=tags or [],
max_concurrency=concurrency_level,
]
+ _callbacks,
**run_config,
)
for example in examples
]
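The rewritten _prepare_run_on_dataset replaces the tags and concurrency_level parameters with a single RunnableConfig: each example gets the tracing/evaluation callbacks and the shared progress bar prepended to whatever callbacks the caller supplied, and the rest of the caller's config is spread into the per-example config. A simplified sketch of that merge, using plain dicts in place of RunnableConfig and pre-built placeholder callbacks:

from typing import Any, Dict, List, Optional


def build_example_configs(
    examples: List[Any],
    run_config: Optional[Dict[str, Any]],
    per_example_callbacks: List[Any],
) -> List[Dict[str, Any]]:
    """Hypothetical stand-in for the per-example config construction above."""
    # Default when the caller passed no config (the diff uses max_concurrency=5).
    run_config = run_config or {"max_concurrency": 5}
    user_callbacks = run_config.get("callbacks", [])
    configs = []
    for _example in examples:
        cfg = dict(run_config)
        # In the real code these callbacks are built per example.id (tracer,
        # evaluator handler, progress bar); here they are passed in for brevity.
        cfg["callbacks"] = list(per_example_callbacks) + list(user_callbacks)
        configs.append(cfg)
    return configs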
@@ -947,16 +900,14 @@ def _collect_test_results(
async def arun_on_dataset(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
client: Optional[Client] = None,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
concurrency_level: int = 5,
config: Optional[RunnableConfig] = None,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""
@@ -964,24 +915,22 @@ async def arun_on_dataset(
and store traces to the specified project name.
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Either a runnable or a constructor for a runnable.
If your runnable contains state (e.g. an agent or
chain with memory), you should pass in a constructor
function that returns a new runnable instance for each
example row in the dataset.
client: LangSmith client to use to read the dataset, and to
log feedback and run traces.
dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
evaluation: Optional evaluation configuration to use when evaluating
concurrency_level: The number of async tasks to run concurrently.
evaluation: Optional evaluation configuration to specify which
evaluators to run on the results of the chain and which
key(s) to consider during evaluation.
config: Optional configuration to specify concurrency, tags,
or other information to apply to the chain runs.
project_name: Name of the project to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
Defaults to a friendly name.
verbose: Whether to print progress.
tags: Tags to add to each run in the project.
input_mapper: A function to map to the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize more complex schema or if your dataset
has inputs with keys that differ from what is expected by your chain
or agent.
Returns:
A dictionary containing the run's project name and the
resulting model outputs.
@@ -993,7 +942,6 @@ async def arun_on_dataset(
.. code-block:: python
from langsmith import Client
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
            from langchain.smith import RunEvalConfig, arun_on_dataset
@@ -1020,11 +968,9 @@ async def arun_on_dataset(
]
)
client = Client()
await arun_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
dataset_name="<my_dataset_name>",
                llm_or_chain_factory=construct_chain,
evaluation=evaluation_config,
)
@@ -1060,7 +1006,6 @@ async def arun_on_dataset(
)
await arun_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
evaluation=evaluation_config,
@@ -1073,15 +1018,14 @@ async def arun_on_dataset(
f"{kwargs.keys()}.",
DeprecationWarning,
)
client = client or Client()
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
client,
dataset_name,
llm_or_chain_factory,
project_name,
evaluation,
tags,
input_mapper,
concurrency_level,
config=config,
)
batch_results = await runnable_utils.gather_with_concurrency(
@@ -1090,7 +1034,6 @@ async def arun_on_dataset(
functools.partial(
_arun_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
@@ -1108,16 +1051,14 @@ async def arun_on_dataset(
def run_on_dataset(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
client: Optional[Client] = None,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
concurrency_level: int = 5,
config: Optional[RunnableConfig] = None,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""
@@ -1135,7 +1076,7 @@ def run_on_dataset(
results of the chain
concurrency_level: The number of async tasks to run concurrently.
project_name: Name of the project to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
            Defaults to a friendly name.
verbose: Whether to print progress.
tags: Tags to add to each run in the project.
input_mapper: A function to map to the inputs dictionary from an Example
@@ -1155,7 +1096,6 @@ def run_on_dataset(
.. code-block:: python
from langsmith import Client
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
            from langchain.smith import RunEvalConfig, run_on_dataset
@@ -1182,12 +1122,11 @@ def run_on_dataset(
]
)
client = Client()
run_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
dataset_name="<my_dataset_name>",
llm_or_chain_factory=construct_chain,
evaluation=evaluation_config,
verbose=True,
)
You can also create custom evaluators by subclassing the
@@ -1222,7 +1161,6 @@ def run_on_dataset(
)
run_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
evaluation=evaluation_config,
@@ -1235,15 +1173,14 @@ def run_on_dataset(
f"{kwargs.keys()}.",
DeprecationWarning,
)
client = client or Client()
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
client,
dataset_name,
llm_or_chain_factory,
project_name,
evaluation,
tags,
input_mapper,
concurrency_level,
config=config,
)
with runnable_config.get_executor_for_config(configs[0]) as executor:
batch_results = list(
@@ -1251,7 +1188,6 @@ def run_on_dataset(
functools.partial(
_run_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,