Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-16 08:06:14 +00:00)
Add Input Mapper in run_on_dataset (#6894)
If you create a dataset from runs and later run the same chain or LLM on it, it usually works great. But if you have an agent dataset and want to run a different agent on it, or your examples have a more complex schema, it's hard for us to map these values automatically every time. This PR lets you pass in an input_mapper function that converts the example inputs to whatever format your model expects.
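For illustration, a minimal usage sketch of the new argument (the import path, dataset name, keys, and chain factory below are assumptions, not part of this diff):

    from langchain.client import run_on_dataset  # assumed export location in this version

    def input_mapper(example_inputs: dict) -> dict:
        # Rename the key stored in the dataset to the key the chain expects.
        return {"question": example_inputs["query"]}  # hypothetical keys

    results = run_on_dataset(
        "my-agent-dataset",         # hypothetical dataset name
        lambda: build_my_chain(),   # hypothetical chain factory
        input_mapper=input_mapper,
    )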
This commit is contained in:
parent 76d03f398d
commit 429f4dbe4d
@@ -139,6 +139,7 @@ async def _arun_llm(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """
     Asynchronously run the language model.
@@ -148,6 +149,7 @@ async def _arun_llm(
         inputs: The input dictionary.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
+        input_mapper: Optional function to map inputs to the expected format.
 
     Returns:
         The LLMResult or ChatResult.
@@ -155,7 +157,13 @@ async def _arun_llm(
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if isinstance(llm, BaseLLM):
+    if input_mapper is not None:
+        if not isinstance(llm, (BaseLLM, BaseChatModel)):
+            raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
+        llm_output = await llm.agenerate(
+            input_mapper(inputs), callbacks=callbacks, tags=tags
+        )
+    elif isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
             llm_output = await llm.agenerate(
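As a sketch of what the new branch above consumes: for a plain LLM, agenerate takes a list of prompt strings, so a mapper on this path should return List[str] (the key below is an assumption):

    def llm_input_mapper(inputs: dict) -> list:
        # Collapse the example dict into the list-of-prompts shape
        # that BaseLLM.agenerate expects.
        return [f"Question: {inputs['question']}\nAnswer:"]  # 'question' is a hypothetical key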
@@ -191,6 +199,7 @@ async def _arun_llm_or_chain(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """
     Asynchronously run the Chain or language model.
@@ -201,6 +210,7 @@ async def _arun_llm_or_chain(
         n_repetitions: The number of times to run the model on each example.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
+        input_mapper: Optional function to map the input to the expected format.
 
     Returns:
         A list of outputs.
@@ -223,12 +233,16 @@ async def _arun_llm_or_chain(
                     example.inputs,
                     tags=tags,
                     callbacks=callbacks,
+                    input_mapper=input_mapper,
                 )
             else:
                 chain = llm_or_chain_factory()
-                inputs_ = example.inputs
-                if len(inputs_) == 1:
-                    inputs_ = next(iter(inputs_.values()))
+                if input_mapper is not None:
+                    inputs_ = input_mapper(example.inputs)
+                else:
+                    inputs_ = example.inputs
+                    if len(inputs_) == 1:
+                        inputs_ = next(iter(inputs_.values()))
                 output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
             outputs.append(output)
         except Exception as e:
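Note the control flow above: with no mapper, a single-key example is still unwrapped to its bare value, but a mapper takes over the conversion entirely. A hypothetical mapper for an agent-style chain might flatten several stored keys into the single "input" string many agents expect:

    def agent_input_mapper(inputs: dict) -> dict:
        # 'instruction' and 'context' are assumed dataset keys; 'input' is the
        # single key a typical AgentExecutor-style chain reads.
        return {"input": f"{inputs['instruction']}\n\n{inputs['context']}"}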
@@ -333,6 +347,7 @@ async def arun_on_examples(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the chain on examples and store traces
@@ -354,6 +369,11 @@ async def arun_on_examples(
             client will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful
+            if your model needs to deserialize a more complex schema or if your
+            dataset has inputs with keys that differ from what your chain or
+            agent expects.
 
     Returns:
         A dictionary mapping example ids to the model outputs.
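To make the "deserialize a more complex schema" case concrete: for a chat model, agenerate takes a list of message lists, so a mapper could rebuild message objects from serialized dicts (the "messages" key is an assumed dataset schema):

    from langchain.schema import HumanMessage

    def chat_input_mapper(inputs: dict) -> list:
        # Rebuild message objects from serialized {"content": ...} dicts.
        return [[HumanMessage(content=m["content"]) for m in inputs["messages"]]]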
@@ -377,6 +397,7 @@ async def arun_on_examples(
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
+            input_mapper=input_mapper,
         )
         results[str(example.id)] = result
         job_state["num_processed"] += 1
@@ -407,6 +428,7 @@ def run_llm(
     callbacks: Callbacks,
     *,
     tags: Optional[List[str]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """
     Run the language model on the example.
@@ -416,14 +438,18 @@ def run_llm(
         inputs: The input dictionary.
         callbacks: The callbacks to use during the run.
         tags: Optional tags to add to the run.
-
+        input_mapper: Optional function to map inputs to the expected format.
     Returns:
         The LLMResult or ChatResult.
     Raises:
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if isinstance(llm, BaseLLM):
+    if input_mapper is not None:
+        if not isinstance(llm, (BaseLLM, BaseChatModel)):
+            raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
+        llm_output = llm.generate(input_mapper(inputs), callbacks=callbacks, tags=tags)
+    elif isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
             llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
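A direct call mirroring this synchronous path, sketched with the FakeLLM helper used by this repo's unit tests (the prompt strings are hypothetical):

    from tests.unit_tests.llms.fake_llm import FakeLLM

    llm = FakeLLM(queries={"mapped prompt": "42"})
    result = run_llm(
        llm,
        {"raw input": "ignored by the mapper"},    # inputs as stored on the Example
        None,                                      # callbacks
        input_mapper=lambda _: ["mapped prompt"],  # plain LLMs need List[str]
    )
    assert result.generations[0][0].text == "42"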
@@ -455,6 +481,7 @@ def run_llm_or_chain(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """
     Run the Chain or language model synchronously.
@@ -483,13 +510,20 @@ def run_llm_or_chain(
         try:
             if isinstance(llm_or_chain_factory, BaseLanguageModel):
                 output: Any = run_llm(
-                    llm_or_chain_factory, example.inputs, callbacks, tags=tags
+                    llm_or_chain_factory,
+                    example.inputs,
+                    callbacks,
+                    tags=tags,
+                    input_mapper=input_mapper,
                 )
             else:
                 chain = llm_or_chain_factory()
-                inputs_ = example.inputs
-                if len(inputs_) == 1:
-                    inputs_ = next(iter(inputs_.values()))
+                if input_mapper is not None:
+                    inputs_ = input_mapper(example.inputs)
+                else:
+                    inputs_ = example.inputs
+                    if len(inputs_) == 1:
+                        inputs_ = next(iter(inputs_.values()))
                 output = chain(inputs_, callbacks=callbacks, tags=tags)
             outputs.append(output)
         except Exception as e:
@@ -512,6 +546,7 @@ def run_on_examples(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on examples and store
@@ -532,6 +567,11 @@ def run_on_examples(
             will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful
+            if your model needs to deserialize a more complex schema or if your
+            dataset has inputs with keys that differ from what your chain or
+            agent expects.
 
     Returns:
         A dictionary mapping example ids to the model outputs.
@@ -552,6 +592,7 @@ def run_on_examples(
             num_repetitions,
             tags=tags,
            callbacks=callbacks,
+            input_mapper=input_mapper,
        )
        if verbose:
            print(f"{i+1} processed", flush=True, end="\r")
@@ -599,6 +640,7 @@ async def arun_on_dataset(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the Chain or language model on a dataset
@@ -620,7 +662,11 @@ async def arun_on_dataset(
             client will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
-
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful
+            if your model needs to deserialize a more complex schema or if your
+            dataset has inputs with keys that differ from what your chain or
+            agent expects.
     Returns:
         A dictionary containing the run's project name and the resulting model outputs.
     """
@@ -638,6 +684,7 @@ async def arun_on_dataset(
         client=client_,
         tags=tags,
         run_evaluators=run_evaluators,
+        input_mapper=input_mapper,
     )
     return {
         "project_name": project_name,
@@ -655,6 +702,7 @@ def run_on_dataset(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on a dataset and store traces
@@ -676,6 +724,11 @@ def run_on_dataset(
             will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is useful
+            if your model needs to deserialize a more complex schema or if your
+            dataset has inputs with keys that differ from what your chain or
+            agent expects.
 
     Returns:
         A dictionary containing the run's project name and the resulting model outputs.
@@ -693,6 +746,7 @@ def run_on_dataset(
         tags=tags,
         run_evaluators=run_evaluators,
         client=client_,
+        input_mapper=input_mapper,
     )
     return {
         "project_name": project_name,
@@ -10,13 +10,16 @@ from langchainplus_sdk.schemas import Dataset, Example
 
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.base import Chain
+from langchain.chains.transform import TransformChain
 from langchain.client.runner_utils import (
     InputFormatError,
     _get_messages,
     _get_prompts,
     arun_on_dataset,
+    run_llm,
     run_llm_or_chain,
 )
+from langchain.schema import LLMResult
 from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
@@ -75,6 +78,57 @@ def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
         _get_prompts(inputs)
 
 
+def test_run_llm_or_chain_with_input_mapper() -> None:
+    example = Example(
+        id=uuid.uuid4(),
+        created_at=_CREATED_AT,
+        inputs={"the wrong input": "1", "another key": "2"},
+        outputs={"output": "2"},
+        dataset_id=str(uuid.uuid4()),
+    )
+
+    def run_val(inputs: dict) -> dict:
+        assert "the right input" in inputs
+        return {"output": "2"}
+
+    mock_chain = TransformChain(
+        input_variables=["the right input"],
+        output_variables=["output"],
+        transform=run_val,
+    )
+
+    def input_mapper(inputs: dict) -> dict:
+        assert "the wrong input" in inputs
+        return {"the right input": inputs["the wrong input"]}
+
+    result = run_llm_or_chain(
+        example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
+    )
+    assert len(result) == 1
+    assert result[0] == {"output": "2", "the right input": "1"}
+    bad_result = run_llm_or_chain(
+        example,
+        lambda: mock_chain,
+        n_repetitions=1,
+    )
+    assert len(bad_result) == 1
+    assert "Error" in bad_result[0]
+
+    # Try with LLM
+    def llm_input_mapper(inputs: dict) -> List[str]:
+        assert "the wrong input" in inputs
+        return ["the right input"]
+
+    mock_llm = FakeLLM(queries={"the right input": "somenumber"})
+    result = run_llm_or_chain(
+        example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
+    )
+    assert len(result) == 1
+    llm_result = result[0]
+    assert isinstance(llm_result, LLMResult)
+    assert llm_result.generations[0][0].text == "somenumber"
+
+
 @pytest.mark.parametrize(
     "inputs",
     [
@@ -171,6 +225,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         n_repetitions: int,
         tags: Optional[List[str]] = None,
         callbacks: Optional[Any] = None,
+        **kwargs: Any,
     ) -> List[Dict[str, Any]]:
         return [
             {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)