diff --git a/langchain/chains/base.py b/langchain/chains/base.py index e041cdea298..848664e5856 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -1,6 +1,6 @@ """Base interface that all chains should implement.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Extra, Field @@ -74,18 +74,28 @@ class Chain(BaseModel, ABC): """Run the logic of this chain and return the output.""" def __call__( - self, inputs: Dict[str, Any], return_only_outputs: bool = False + self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False ) -> Dict[str, str]: """Run the logic of this chain and add to output if desired. Args: - inputs: Dictionary of inputs. + inputs: Dictionary of inputs, or single input if chain expects + only one param. return_only_outputs: boolean for whether to return only outputs in the response. If True, only new keys generated by this chain will be returned. If False, both input keys and new keys generated by this chain will be returned. Defaults to False. """ + if not isinstance(inputs, dict): + if len(self.input_keys) != 1: + raise ValueError( + f"A single string input was passed in, but this chain expects " + f"multiple inputs ({self.input_keys}). When a chain expects " + f"multiple inputs, please call it by passing in a dictionary, " + "eg `chain({'foo': 1, 'bar': 2})`" + ) + inputs = {self.input_keys[0]: inputs} if self.memory is not None: external_context = self.memory.load_memory_variables(inputs) inputs = dict(inputs, **external_context) diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 2ab19c31b70..8fcfa918168 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -11,11 +11,12 @@ class FakeChain(Chain, BaseModel): """Fake chain class for testing purposes.""" be_correct: bool = True + the_input_keys: List[str] = ["foo"] @property def input_keys(self) -> List[str]: - """Input key of foo.""" - return ["foo"] + """Input keys.""" + return self.the_input_keys @property def output_keys(self) -> List[str]: @@ -48,3 +49,17 @@ def test_correct_call() -> None: chain = FakeChain() output = chain({"foo": "bar"}) assert output == {"foo": "bar", "bar": "baz"} + + +def test_single_input_correct() -> None: + """Test passing single input works.""" + chain = FakeChain() + output = chain("bar") + assert output == {"foo": "bar", "bar": "baz"} + + +def test_single_input_error() -> None: + """Test passing single input errors as expected.""" + chain = FakeChain(the_input_keys=["foo", "bar"]) + with pytest.raises(ValueError): + chain("bar")