Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-01-30 13:50:11 +00:00

Compare commits: 2 commits (langchain-... → wfh/implic...)

| Author | SHA1 | Date |
|---|---|---|
|  | d180208915 |  |
|  | 625e598111 |  |
@@ -278,35 +278,21 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
 ## Shared data validation utilities
 def _validate_example_inputs_for_language_model(
     first_example: Example,
-    input_mapper: Optional[Callable[[Dict], Any]],
 ) -> None:
-    if input_mapper:
-        prompt_input = input_mapper(first_example.inputs)
-        if not isinstance(prompt_input, str) and not (
-            isinstance(prompt_input, list)
-            and all(isinstance(msg, BaseMessage) for msg in prompt_input)
-        ):
-            raise InputFormatError(
-                "When using an input_mapper to prepare dataset example inputs"
-                " for an LLM or chat model, the output must a single string or"
-                " a list of chat messages."
-                f"\nGot: {prompt_input} of type {type(prompt_input)}."
-            )
-    else:
-        try:
-            _get_prompt(first_example.inputs)
-        except InputFormatError:
-            try:
-                _get_messages(first_example.inputs)
-            except InputFormatError:
-                raise InputFormatError(
-                    "Example inputs do not match language model input format. "
-                    "Expected a dictionary with messages or a single prompt."
-                    f" Got: {first_example.inputs}"
-                    " Please update your dataset OR provide an input_mapper"
-                    " to convert the example.inputs to a compatible format"
-                    " for the llm or chat model you wish to evaluate."
-                )
+    try:
+        _get_prompt(first_example.inputs)
+    except InputFormatError:
+        try:
+            _get_messages(first_example.inputs)
+        except InputFormatError:
+            raise InputFormatError(
+                "Example inputs do not match language model input format. "
+                "Expected a dictionary with messages or a single prompt."
+                f" Got: {first_example.inputs}"
+                " Please update your dataset OR provide an input_mapper"
+                " to convert the example.inputs to a compatible format"
+                " for the llm or chat model you wish to evaluate."
+            )
 
 
 def _validate_example_inputs_for_chain(
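Net effect of this hunk: validation no longer branches on an `input_mapper`; it always probes the example's inputs directly, trying the single-prompt format first and falling back to chat messages. A minimal, self-contained sketch of that probe-and-fallback pattern; the stub bodies of `_get_prompt`/`_get_messages` are assumptions here, simplified stand-ins for the private helpers, not the library code:

```python
# Standalone sketch of the validation fallback above; the helpers are
# simplified stand-ins for langchain's private _get_prompt/_get_messages.
from typing import Any, Dict


class InputFormatError(ValueError):
    """Raised when example inputs cannot be coerced to a prompt or messages."""


def _get_prompt(inputs: Dict[str, Any]) -> str:
    # Accept {"prompt": "..."} or a single string-valued key.
    if isinstance(inputs.get("prompt"), str):
        return inputs["prompt"]
    if len(inputs) == 1 and isinstance(next(iter(inputs.values())), str):
        return next(iter(inputs.values()))
    raise InputFormatError(f"No prompt in {inputs}")


def _get_messages(inputs: Dict[str, Any]) -> list:
    # Accept {"messages": [...]}.
    if isinstance(inputs.get("messages"), list):
        return inputs["messages"]
    raise InputFormatError(f"No messages in {inputs}")


def validate(inputs: Dict[str, Any]) -> None:
    # Same shape as the new _validate_example_inputs_for_language_model:
    # try the prompt format first, fall back to messages, else raise.
    try:
        _get_prompt(inputs)
    except InputFormatError:
        try:
            _get_messages(inputs)
        except InputFormatError:
            raise InputFormatError(f"Inputs not usable by an LLM: {inputs}")


validate({"prompt": "Hello"})                              # ok: prompt format
validate({"messages": [{"role": "user", "content": "hi"}]})  # ok: messages fallback
```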
@@ -349,16 +335,15 @@ def _validate_example_inputs_for_chain(
 def _validate_example_inputs(
     example: Example,
     llm_or_chain_factory: MCF,
-    input_mapper: Optional[Callable[[Dict], Any]],
 ) -> None:
     """Validate that the example inputs are valid for the model."""
     if isinstance(llm_or_chain_factory, BaseLanguageModel):
-        _validate_example_inputs_for_language_model(example, input_mapper)
+        _validate_example_inputs_for_language_model(example)
     else:
         chain = llm_or_chain_factory()
         if isinstance(chain, Chain):
             # Otherwise it's a runnable
-            _validate_example_inputs_for_chain(example, chain, input_mapper)
+            _validate_example_inputs_for_chain(example, chain)
         elif isinstance(chain, Runnable):
             logger.debug(f"Skipping input validation for {chain}")
 
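The surrounding dispatch is unchanged apart from the dropped argument: language models are validated directly, a chain factory is invoked once so the instance can be inspected, and generic runnables skip validation. A sketch of that dispatch, with local stand-in classes for `BaseLanguageModel`, `Chain`, and `Runnable` (not the langchain types):

```python
# Local stand-ins for langchain's BaseLanguageModel, Chain and Runnable.
class BaseLanguageModel: ...
class Runnable: ...
class Chain(Runnable): ...


def validate_example_inputs(llm_or_chain_factory) -> str:
    # Models are passed directly; anything else is a zero-arg factory.
    if isinstance(llm_or_chain_factory, BaseLanguageModel):
        return "validated as LLM/chat-model inputs"
    chain = llm_or_chain_factory()  # fresh instance, no carried-over state
    if isinstance(chain, Chain):
        return "validated against the chain's expected inputs"
    if isinstance(chain, Runnable):
        return "validation skipped for generic runnables"
    return "unknown target"


print(validate_example_inputs(BaseLanguageModel()))
print(validate_example_inputs(Chain))     # a class itself works as a factory
print(validate_example_inputs(Runnable))
```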
@@ -591,7 +576,6 @@ async def _arun_llm(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[str, BaseMessage]:
     """Asynchronously run the language model.
 
@@ -600,7 +584,6 @@ async def _arun_llm(
         inputs: The input dictionary.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
-        input_mapper: Optional function to map inputs to the expected format.
 
     Returns:
         The LLMResult or ChatResult.
@@ -608,36 +591,16 @@ async def _arun_llm(
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if input_mapper is not None:
-        prompt_or_messages = input_mapper(inputs)
-        if isinstance(prompt_or_messages, str):
-            return await llm.apredict(
-                prompt_or_messages, callbacks=callbacks, tags=tags
-            )
-        elif isinstance(prompt_or_messages, list) and all(
-            isinstance(msg, BaseMessage) for msg in prompt_or_messages
-        ):
-            return await llm.apredict_messages(
-                prompt_or_messages, callbacks=callbacks, tags=tags
-            )
-        else:
-            raise InputFormatError(
-                "Input mapper returned invalid format"
-                f" {prompt_or_messages}"
-                "\nExpected a single string or list of chat messages."
-            )
-
-    else:
-        try:
-            prompt = _get_prompt(inputs)
-            llm_output: Union[str, BaseMessage] = await llm.apredict(
-                prompt, callbacks=callbacks, tags=tags
-            )
-        except InputFormatError:
-            messages = _get_messages(inputs)
-            llm_output = await llm.apredict_messages(
-                messages, callbacks=callbacks, tags=tags
-            )
+    try:
+        prompt = _get_prompt(inputs)
+        llm_output: Union[str, BaseMessage] = await llm.apredict(
+            prompt, callbacks=callbacks, tags=tags
+        )
+    except InputFormatError:
+        messages = _get_messages(inputs)
+        llm_output = await llm.apredict_messages(
+            messages, callbacks=callbacks, tags=tags
+        )
     return llm_output
 
 
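`_arun_llm` now always uses the same two-step coercion as the validator: try `_get_prompt`, and on `InputFormatError` fall back to `_get_messages`. An async sketch of that control flow; `FakeLLM` mimics the two `BaseLanguageModel` methods the hunk calls and, like the compact helpers, is an illustration rather than library code:

```python
import asyncio
from typing import Any, Dict


# Compact restatement of the stub helpers from the validation sketch.
class InputFormatError(ValueError): ...


def _get_prompt(inputs: Dict[str, Any]) -> str:
    if isinstance(inputs.get("prompt"), str):
        return inputs["prompt"]
    raise InputFormatError(f"No prompt in {inputs}")


def _get_messages(inputs: Dict[str, Any]) -> list:
    if isinstance(inputs.get("messages"), list):
        return inputs["messages"]
    raise InputFormatError(f"No messages in {inputs}")


class FakeLLM:
    async def apredict(self, prompt: str) -> str:
        return f"completion for: {prompt!r}"

    async def apredict_messages(self, messages: list) -> str:
        return f"reply to {len(messages)} chat messages"


async def arun_llm(llm: FakeLLM, inputs: dict) -> str:
    try:
        prompt = _get_prompt(inputs)       # prompt format first...
        return await llm.apredict(prompt)
    except InputFormatError:
        messages = _get_messages(inputs)   # ...then the messages fallback
        return await llm.apredict_messages(messages)


print(asyncio.run(arun_llm(FakeLLM(), {"prompt": "2+2?"})))
print(asyncio.run(arun_llm(FakeLLM(), {"messages": [{"role": "user", "content": "hi"}]})))
```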
@@ -647,19 +610,17 @@ async def _arun_chain(
     callbacks: Callbacks,
     *,
     tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
-    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
     if isinstance(chain, Chain):
-        if isinstance(inputs_, dict) and len(inputs_) == 1:
-            val = next(iter(inputs_.values()))
+        if isinstance(inputs, dict) and len(inputs) == 1:
+            val = next(iter(inputs.values()))
             output = await chain.acall(val, callbacks=callbacks, tags=tags)
         else:
-            output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
+            output = await chain.acall(inputs, callbacks=callbacks, tags=tags)
     else:
         runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
-        output = await chain.ainvoke(inputs_, config=runnable_config)
+        output = await chain.ainvoke(inputs, config=runnable_config)
     return output
 
 
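Both `_arun_chain` and its sync twin below keep the single-key convenience rule: a one-key input dict is collapsed to its bare value before a `Chain` is called. A tiny sketch of that rule in isolation:

```python
# Sketch of the single-key unwrapping rule in _arun_chain/_run_chain:
# a one-key input dict is collapsed to its bare value before the chain
# is called; anything else passes through unchanged.
def unwrap_inputs(inputs):
    if isinstance(inputs, dict) and len(inputs) == 1:
        return next(iter(inputs.values()))
    return inputs


assert unwrap_inputs({"question": "What is 2+2?"}) == "What is 2+2?"
assert unwrap_inputs({"a": 1, "b": 2}) == {"a": 1, "b": 2}
print("unwrap rule holds")
```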
@@ -668,7 +629,6 @@ async def _arun_llm_or_chain(
     config: RunnableConfig,
     *,
     llm_or_chain_factory: MCF,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str, LLMResult, ChatResult]:
     """Asynchronously run the Chain or language model.
 
@@ -693,7 +653,6 @@ async def _arun_llm_or_chain(
                 example.inputs,
                 tags=config["tags"],
                 callbacks=config["callbacks"],
-                input_mapper=input_mapper,
             )
         else:
             chain = llm_or_chain_factory()
@@ -702,7 +661,6 @@ async def _arun_llm_or_chain(
                 example.inputs,
                 tags=config["tags"],
                 callbacks=config["callbacks"],
-                input_mapper=input_mapper,
             )
         result = output
     except Exception as e:
@@ -775,19 +733,17 @@ def _run_chain(
     callbacks: Callbacks,
     *,
     tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
-    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
     if isinstance(chain, Chain):
-        if isinstance(inputs_, dict) and len(inputs_) == 1:
-            val = next(iter(inputs_.values()))
+        if isinstance(inputs, dict) and len(inputs) == 1:
+            val = next(iter(inputs.values()))
             output = chain(val, callbacks=callbacks, tags=tags)
         else:
-            output = chain(inputs_, callbacks=callbacks, tags=tags)
+            output = chain(inputs, callbacks=callbacks, tags=tags)
     else:
         runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
-        output = chain.invoke(inputs_, config=runnable_config)
+        output = chain.invoke(inputs, config=runnable_config)
     return output
 
 
@@ -796,7 +752,6 @@ def _run_llm_or_chain(
     config: RunnableConfig,
     *,
     llm_or_chain_factory: MCF,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str, LLMResult, ChatResult]:
     """
     Run the Chain or language model synchronously.
@@ -822,7 +777,6 @@ def _run_llm_or_chain(
                 example.inputs,
                 config["callbacks"],
                 tags=config["tags"],
-                input_mapper=input_mapper,
             )
         else:
             chain = llm_or_chain_factory()
@@ -831,7 +785,6 @@ def _run_llm_or_chain(
                 example.inputs,
                 config["callbacks"],
                 tags=config["tags"],
-                input_mapper=input_mapper,
             )
         result = output
     except Exception as e:
@@ -879,9 +832,7 @@ def _prepare_run_on_dataset(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
-    concurrency_level: int = 5,
+    config: Optional[RunnableConfig] = None,
 ) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
     project_name = project_name or name_generation.random_name()
     wrapped_model, project_name, dataset, examples = _prepare_eval_run(
@@ -891,8 +842,10 @@ def _prepare_run_on_dataset(
     run_evaluators = _setup_evaluation(
         wrapped_model, examples, evaluation, dataset.data_type
     )
-    _validate_example_inputs(examples[0], wrapped_model, input_mapper)
+    _validate_example_inputs(examples[0], wrapped_model)
     progress_bar = progress.ProgressBarCallback(len(examples))
+    run_config = config or RunnableConfig(max_concurrency=5)
+    _callbacks = run_config["callbacks"] if "callbacks" in run_config else []
     configs = [
         RunnableConfig(
             callbacks=[
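`_prepare_run_on_dataset` now seeds a base `RunnableConfig` (defaulting `max_concurrency` to 5 when no config is passed) and lifts any caller-supplied callbacks out so they can be appended to the per-example callbacks built below. A sketch of that step with plain dicts standing in for `RunnableConfig`, which is a `TypedDict` in langchain:

```python
# Sketch of the base-config handling above; a plain dict models the
# RunnableConfig TypedDict.
def prepare_base_config(config=None):
    # Default to a max_concurrency of 5 when the caller passed no config.
    run_config = config or {"max_concurrency": 5}
    # Lift caller callbacks out so they can be appended to the per-example
    # tracing/progress callbacks built later.
    callbacks = run_config["callbacks"] if "callbacks" in run_config else []
    return run_config, callbacks


print(prepare_base_config())
print(prepare_base_config({"callbacks": ["my-handler"], "tags": ["eval"]}))
```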
@@ -909,9 +862,9 @@ def _prepare_run_on_dataset(
                     example_id=example.id,
                 ),
                 progress_bar,
-            ],
-            tags=tags or [],
-            max_concurrency=concurrency_level,
+            ]
+            + _callbacks,
+            **run_config,
         )
         for example in examples
     ]
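Each example then gets its own `RunnableConfig`: a per-example tracer, the shared progress bar, and the caller's callbacks, with the remaining base-config keys splatted through. A dict-based sketch of that fan-out; `make_tracer` is a hypothetical stand-in for the `LangChainTracer` bound to `example.id`, and the sketch drops the `callbacks` key from the passthrough so it is not passed twice:

```python
# Sketch of the per-example config fan-out; strings stand in for callback
# handlers, and make_tracer is a hypothetical stand-in for LangChainTracer.
def make_tracer(example_id: str) -> str:
    return f"tracer<{example_id}>"


def build_configs(example_ids, run_config, extra_callbacks, progress_bar="progress-bar"):
    # Keys other than callbacks (max_concurrency, tags, ...) pass through,
    # mirroring **run_config in the hunk above.
    passthrough = {k: v for k, v in run_config.items() if k != "callbacks"}
    return [
        {
            "callbacks": [make_tracer(eid), progress_bar] + extra_callbacks,
            **passthrough,
        }
        for eid in example_ids
    ]


for cfg in build_configs(["ex-1", "ex-2"], {"max_concurrency": 5}, ["my-handler"]):
    print(cfg)
```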
@@ -947,16 +900,14 @@ def _collect_test_results(
 
 
 async def arun_on_dataset(
-    client: Client,
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
+    client: Optional[Client] = None,
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
-    concurrency_level: int = 5,
+    config: Optional[RunnableConfig] = None,
     project_name: Optional[str] = None,
     verbose: bool = False,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     """
@@ -964,24 +915,22 @@ async def arun_on_dataset(
     and store traces to the specified project name.
 
     Args:
+        dataset_name: Name of the dataset to run the chain on.
+        llm_or_chain_factory: Either a runnable or a constructor for a runnable.
+            If your runnable contains state (e.g. an agent or
+            chain with memory), you should pass in a constructor
+            function that returns a new runnable instance for each
+            example row in the dataset.
         client: LangSmith client to use to read the dataset, and to
             log feedback and run traces.
-        dataset_name: Name of the dataset to run the chain on.
-        llm_or_chain_factory: Language model or Chain constructor to run
-            over the dataset. The Chain constructor is used to permit
-            independent calls on each example without carrying over state.
-        evaluation: Optional evaluation configuration to use when evaluating
-        concurrency_level: The number of async tasks to run concurrently.
+        evaluation: Optional evaluation configuration to specify which
+            evaluators to run on the results of the chain and which
+            key(s) to consider during evaluation.
+        config: Optional configuration to specify concurrency, tags,
+            or other information to apply to the chain runs.
         project_name: Name of the project to store the traces in.
-            Defaults to {dataset_name}-{chain class name}-{datetime}.
+            Defaults to a friendly name.
         verbose: Whether to print progress.
-        tags: Tags to add to each run in the project.
-        input_mapper: A function to map to the inputs dictionary from an Example
-            to the format expected by the model to be evaluated. This is useful if
-            your model needs to deserialize more complex schema or if your dataset
-            has inputs with keys that differ from what is expected by your chain
-            or agent.
 
     Returns:
         A dictionary containing the run's project name and the
         resulting model outputs.
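Taken together with the signature hunk above, the public call shape becomes: `dataset_name` and the factory lead, `client` is an optional keyword, and concurrency and tags travel inside `config`. A hedged usage sketch of that shape; it mirrors this branch's diff, not necessarily the released langchain API, and `construct_chain`/`evaluation_config` are the user-defined objects from the docstring example:

```python
# Hypothetical post-change call shape, per this branch's diff.
import asyncio

from langchain.schema.runnable import RunnableConfig
from langchain.smith import arun_on_dataset


async def main() -> None:
    results = await arun_on_dataset(
        dataset_name="<my_dataset_name>",
        llm_or_chain_factory=construct_chain,  # noqa: F821 (user-defined factory)
        evaluation=evaluation_config,  # noqa: F821 (user-defined eval config)
        # Concurrency, tags, and callbacks now travel in one RunnableConfig:
        config=RunnableConfig(max_concurrency=5, tags=["my-eval"]),
    )
    print(results["project_name"])


asyncio.run(main())
```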
@@ -993,7 +942,6 @@ async def arun_on_dataset(
 
     .. code-block:: python
 
-        from langsmith import Client
         from langchain.chat_models import ChatOpenAI
         from langchain.chains import LLMChain
         from langchain.smith import smith_eval.RunEvalConfig, arun_on_dataset
@@ -1020,11 +968,9 @@ async def arun_on_dataset(
             ]
         )
 
-        client = Client()
         await arun_on_dataset(
-            client,
-            "<my_dataset_name>",
-            construct_chain,
+            dataset_name="<my_dataset_name>",
+            llm_or_chain_config=construct_chain,
             evaluation=evaluation_config,
         )
 
@@ -1060,7 +1006,6 @@ async def arun_on_dataset(
         )
 
         await arun_on_dataset(
-            client,
             "<my_dataset_name>",
             construct_chain,
             evaluation=evaluation_config,
@@ -1073,15 +1018,14 @@ async def arun_on_dataset(
             f"{kwargs.keys()}.",
             DeprecationWarning,
         )
+    client = client or Client()
     wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
         client,
         dataset_name,
         llm_or_chain_factory,
         project_name,
         evaluation,
-        tags,
-        input_mapper,
-        concurrency_level,
+        config=config,
     )
 
     batch_results = await runnable_utils.gather_with_concurrency(
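With `client` now optional, both entry points materialize a default lazily via `client = client or Client()`. A sketch of that pattern; the local `Client` class is a stand-in for `langsmith.Client`:

```python
# Sketch of the lazy-default pattern `client = client or Client()`.
class Client:
    """Stand-in for langsmith.Client."""


def resolve_client(client=None):
    # `or` builds the default only when the caller passed None
    # (or any other falsy value).
    return client or Client()


existing = Client()
assert resolve_client(existing) is existing
assert isinstance(resolve_client(None), Client)
```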
@@ -1090,7 +1034,6 @@ async def arun_on_dataset(
         functools.partial(
             _arun_llm_or_chain,
             llm_or_chain_factory=wrapped_model,
-            input_mapper=input_mapper,
         ),
         examples,
         configs,
@@ -1108,16 +1051,14 @@ async def arun_on_dataset(
 
 
 def run_on_dataset(
-    client: Client,
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
+    client: Optional[Client] = None,
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
-    concurrency_level: int = 5,
+    config: Optional[RunnableConfig] = None,
     project_name: Optional[str] = None,
     verbose: bool = False,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     """
@@ -1135,7 +1076,7 @@ def run_on_dataset(
             results of the chain
         concurrency_level: The number of async tasks to run concurrently.
         project_name: Name of the project to store the traces in.
-            Defaults to {dataset_name}-{chain class name}-{datetime}.
+            Defaults to a friendly name
         verbose: Whether to print progress.
         tags: Tags to add to each run in the project.
         input_mapper: A function to map to the inputs dictionary from an Example
@@ -1155,7 +1096,6 @@ def run_on_dataset(
 
     .. code-block:: python
 
-        from langsmith import Client
         from langchain.chat_models import ChatOpenAI
         from langchain.chains import LLMChain
         from langchain.smith import smith_eval.RunEvalConfig, run_on_dataset
@@ -1182,12 +1122,11 @@ def run_on_dataset(
             ]
         )
 
-        client = Client()
         run_on_dataset(
-            client,
-            "<my_dataset_name>",
-            construct_chain,
+            dataset_name="<my_dataset_name>",
+            llm_or_chain_factory=construct_chain,
             evaluation=evaluation_config,
+            verbose=True,
         )
 
     You can also create custom evaluators by subclassing the
@@ -1222,7 +1161,6 @@ def run_on_dataset(
         )
 
         run_on_dataset(
-            client,
             "<my_dataset_name>",
             construct_chain,
             evaluation=evaluation_config,
@@ -1235,15 +1173,14 @@ def run_on_dataset(
             f"{kwargs.keys()}.",
             DeprecationWarning,
         )
+    client = client or Client()
     wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
         client,
         dataset_name,
         llm_or_chain_factory,
         project_name,
         evaluation,
-        tags,
-        input_mapper,
-        concurrency_level,
+        config=config,
     )
     with runnable_config.get_executor_for_config(configs[0]) as executor:
         batch_results = list(
@@ -1251,7 +1188,6 @@ def run_on_dataset(
             functools.partial(
                 _run_llm_or_chain,
                 llm_or_chain_factory=wrapped_model,
-                input_mapper=input_mapper,
             ),
             examples,
             configs,