Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-16 16:11:02 +00:00)
Add Input Mapper in run_on_dataset (#6894)
If you create a dataset from runs and later run the same chain or LLM on it, it usually works great. But if you have an agent dataset and want to run a different agent on it, or your examples have a more complex schema, it is hard to map those values automatically every time. This PR lets you pass an input_mapper function that converts each example's inputs into whatever format your model expects.
This commit is contained in: parent 76d03f398d, commit 429f4dbe4d.
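For context, here is a minimal usage sketch of the new argument. The dataset name, chain, prompt, and key names ("question" vs. "query") are hypothetical, and the import path is assumed from this branch; treat it as an illustration rather than the canonical call.

from langchain.chains import LLMChain
from langchain.client import run_on_dataset  # import path assumed for this branch
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate


def chain_factory() -> LLMChain:
    # Hypothetical chain that expects a single "query" input key.
    return LLMChain(
        llm=OpenAI(temperature=0),
        prompt=PromptTemplate.from_template("Answer briefly: {query}"),
    )


def input_mapper(example_inputs: dict) -> dict:
    # Dataset examples store their text under "question"; the chain wants "query".
    return {"query": example_inputs["question"]}


results = run_on_dataset(
    "my-agent-dataset",  # hypothetical dataset name
    chain_factory,
    input_mapper=input_mapper,
)
print(results["project_name"])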
File: langchain/client/runner_utils.py

@@ -139,6 +139,7 @@ async def _arun_llm(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """
     Asynchronously run the language model.
@@ -148,6 +149,7 @@ async def _arun_llm(
         inputs: The input dictionary.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
+        input_mapper: Optional function to map inputs to the expected format.

     Returns:
         The LLMResult or ChatResult.
@@ -155,7 +157,13 @@ async def _arun_llm(
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if isinstance(llm, BaseLLM):
+    if input_mapper is not None:
+        if not isinstance(llm, (BaseLLM, BaseChatModel)):
+            raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
+        llm_output = await llm.agenerate(
+            input_mapper(inputs), callbacks=callbacks, tags=tags
+        )
+    elif isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
             llm_output = await llm.agenerate(
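Because the mapper's output is handed straight to llm.agenerate / llm.generate (bypassing _get_prompts and _get_messages), it must return whatever the model's generate method accepts. A rough sketch, assuming a hypothetical "question" key in the example inputs:

from typing import Dict, List

from langchain.schema import HumanMessage


def llm_input_mapper(inputs: Dict) -> List[str]:
    # BaseLLM.generate/agenerate take a list of prompt strings.
    return [inputs["question"]]


def chat_input_mapper(inputs: Dict) -> List[List[HumanMessage]]:
    # BaseChatModel.generate/agenerate take a list of message lists.
    return [[HumanMessage(content=inputs["question"])]]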
@@ -191,6 +199,7 @@ async def _arun_llm_or_chain(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """
     Asynchronously run the Chain or language model.
@@ -201,6 +210,7 @@ async def _arun_llm_or_chain(
         n_repetitions: The number of times to run the model on each example.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
+        input_mapper: Optional function to map the input to the expected format.

     Returns:
         A list of outputs.
@@ -223,12 +233,16 @@ async def _arun_llm_or_chain(
                     example.inputs,
                     tags=tags,
                     callbacks=callbacks,
+                    input_mapper=input_mapper,
                 )
             else:
                 chain = llm_or_chain_factory()
-                inputs_ = example.inputs
-                if len(inputs_) == 1:
-                    inputs_ = next(iter(inputs_.values()))
+                if input_mapper is not None:
+                    inputs_ = input_mapper(example.inputs)
+                else:
+                    inputs_ = example.inputs
+                    if len(inputs_) == 1:
+                        inputs_ = next(iter(inputs_.values()))
                 output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
             outputs.append(output)
         except Exception as e:
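For chains, note that the usual single-key unwrapping is skipped whenever a mapper is supplied, so the mapper should return exactly what chain.acall (or chain(...) in the sync path) will receive, typically a dict keyed by the chain's input variables. A sketch with hypothetical key names:

def agent_input_mapper(inputs: dict) -> dict:
    # Example inputs hold {"question": ..., "context": ...}; the chain or agent
    # under test expects a single "input" key, so build that dict here. Because
    # the single-key unwrapping above is bypassed when input_mapper is set,
    # return the value in exactly the shape the chain should receive.
    return {"input": f"{inputs['question']}\n\nContext: {inputs['context']}"}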
@@ -333,6 +347,7 @@ async def arun_on_examples(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the chain on examples and store traces
@@ -354,6 +369,11 @@ async def arun_on_examples(
             client will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: function to map to the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful if
+            your model needs to deserialize more complex schema or if your dataset
+            has inputs with keys that differ from what is expected by your chain
+            or agent.

     Returns:
         A dictionary mapping example ids to the model outputs.
@@ -377,6 +397,7 @@ async def arun_on_examples(
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
+            input_mapper=input_mapper,
         )
         results[str(example.id)] = result
         job_state["num_processed"] += 1
@@ -407,6 +428,7 @@ def run_llm(
     callbacks: Callbacks,
     *,
     tags: Optional[List[str]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """
     Run the language model on the example.
@@ -416,14 +438,18 @@ def run_llm(
         inputs: The input dictionary.
         callbacks: The callbacks to use during the run.
         tags: Optional tags to add to the run.
-
+        input_mapper: function to map to the inputs dictionary from an Example
     Returns:
         The LLMResult or ChatResult.
     Raises:
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if isinstance(llm, BaseLLM):
+    if input_mapper is not None:
+        if not isinstance(llm, (BaseLLM, BaseChatModel)):
+            raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
+        llm_output = llm.generate(input_mapper(inputs), callbacks=callbacks, tags=tags)
+    elif isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
             llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
@@ -455,6 +481,7 @@ def run_llm_or_chain(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """
     Run the Chain or language model synchronously.
@@ -483,13 +510,20 @@ def run_llm_or_chain(
         try:
             if isinstance(llm_or_chain_factory, BaseLanguageModel):
                 output: Any = run_llm(
-                    llm_or_chain_factory, example.inputs, callbacks, tags=tags
+                    llm_or_chain_factory,
+                    example.inputs,
+                    callbacks,
+                    tags=tags,
+                    input_mapper=input_mapper,
                 )
             else:
                 chain = llm_or_chain_factory()
-                inputs_ = example.inputs
-                if len(inputs_) == 1:
-                    inputs_ = next(iter(inputs_.values()))
+                if input_mapper is not None:
+                    inputs_ = input_mapper(example.inputs)
+                else:
+                    inputs_ = example.inputs
+                    if len(inputs_) == 1:
+                        inputs_ = next(iter(inputs_.values()))
                 output = chain(inputs_, callbacks=callbacks, tags=tags)
             outputs.append(output)
         except Exception as e:
@@ -512,6 +546,7 @@ def run_on_examples(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on examples and store
@@ -532,6 +567,11 @@ def run_on_examples(
             will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map to the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful if
+            your model needs to deserialize more complex schema or if your dataset
+            has inputs with keys that differ from what is expected by your chain
+            or agent.

     Returns:
         A dictionary mapping example ids to the model outputs.
@@ -552,6 +592,7 @@ def run_on_examples(
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
+            input_mapper=input_mapper,
         )
         if verbose:
             print(f"{i+1} processed", flush=True, end="\r")
@@ -599,6 +640,7 @@ async def arun_on_dataset(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the Chain or language model on a dataset
@@ -620,7 +662,11 @@ async def arun_on_dataset(
             client will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
-
+        input_mapper: A function to map to the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful if
+            your model needs to deserialize more complex schema or if your dataset
+            has inputs with keys that differ from what is expected by your chain
+            or agent.
     Returns:
         A dictionary containing the run's project name and the resulting model outputs.
     """
@@ -638,6 +684,7 @@ async def arun_on_dataset(
         client=client_,
         tags=tags,
         run_evaluators=run_evaluators,
+        input_mapper=input_mapper,
     )
     return {
         "project_name": project_name,
@@ -655,6 +702,7 @@ def run_on_dataset(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on a dataset and store traces
@@ -676,6 +724,11 @@ def run_on_dataset(
             will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map to the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful if
+            your model needs to deserialize more complex schema or if your dataset
+            has inputs with keys that differ from what is expected by your chain
+            or agent.

     Returns:
         A dictionary containing the run's project name and the resulting model outputs.
@@ -693,6 +746,7 @@ def run_on_dataset(
         tags=tags,
         run_evaluators=run_evaluators,
         client=client_,
+        input_mapper=input_mapper,
     )
     return {
         "project_name": project_name,
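The async entry point threads the argument through in the same way. A minimal sketch of the async path, reusing the hypothetical chain_factory and key names from the earlier sketch (import path and the concurrency_level keyword are assumptions):

import asyncio

from langchain.client import arun_on_dataset  # import path assumed for this branch


async def main() -> None:
    results = await arun_on_dataset(
        "my-agent-dataset",  # hypothetical dataset name
        chain_factory,  # same hypothetical factory as in the earlier sketch
        concurrency_level=5,  # assumed keyword; adjust to the actual signature
        input_mapper=lambda example_inputs: {"query": example_inputs["question"]},
    )
    print(results["project_name"])


asyncio.run(main())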
File: unit tests for langchain.client.runner_utils

@@ -10,13 +10,16 @@ from langchainplus_sdk.schemas import Dataset, Example

 from langchain.base_language import BaseLanguageModel
 from langchain.chains.base import Chain
+from langchain.chains.transform import TransformChain
 from langchain.client.runner_utils import (
     InputFormatError,
     _get_messages,
     _get_prompts,
     arun_on_dataset,
     run_llm,
+    run_llm_or_chain,
 )
+from langchain.schema import LLMResult
 from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM

@@ -75,6 +78,57 @@ def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
         _get_prompts(inputs)


+def test_run_llm_or_chain_with_input_mapper() -> None:
+    example = Example(
+        id=uuid.uuid4(),
+        created_at=_CREATED_AT,
+        inputs={"the wrong input": "1", "another key": "2"},
+        outputs={"output": "2"},
+        dataset_id=str(uuid.uuid4()),
+    )
+
+    def run_val(inputs: dict) -> dict:
+        assert "the right input" in inputs
+        return {"output": "2"}
+
+    mock_chain = TransformChain(
+        input_variables=["the right input"],
+        output_variables=["output"],
+        transform=run_val,
+    )
+
+    def input_mapper(inputs: dict) -> dict:
+        assert "the wrong input" in inputs
+        return {"the right input": inputs["the wrong input"]}
+
+    result = run_llm_or_chain(
+        example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
+    )
+    assert len(result) == 1
+    assert result[0] == {"output": "2", "the right input": "1"}
+    bad_result = run_llm_or_chain(
+        example,
+        lambda: mock_chain,
+        n_repetitions=1,
+    )
+    assert len(bad_result) == 1
+    assert "Error" in bad_result[0]
+
+    # Try with LLM
+    def llm_input_mapper(inputs: dict) -> List[str]:
+        assert "the wrong input" in inputs
+        return ["the right input"]
+
+    mock_llm = FakeLLM(queries={"the right input": "somenumber"})
+    result = run_llm_or_chain(
+        example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
+    )
+    assert len(result) == 1
+    llm_result = result[0]
+    assert isinstance(llm_result, LLMResult)
+    assert llm_result.generations[0][0].text == "somenumber"
+
+
 @pytest.mark.parametrize(
     "inputs",
     [
@@ -171,6 +225,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         n_repetitions: int,
         tags: Optional[List[str]] = None,
         callbacks: Optional[Any] = None,
+        **kwargs: Any,
     ) -> List[Dict[str, Any]]:
         return [
             {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)