diff --git a/docs/examples/chains/chatgpt_clone.ipynb b/docs/examples/chains/chatgpt_clone.ipynb index f958b9234fe..64fe36f7504 100644 --- a/docs/examples/chains/chatgpt_clone.ipynb +++ b/docs/examples/chains/chatgpt_clone.ipynb @@ -963,7 +963,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/docs/getting_started/data_augmented_generation.ipynb b/docs/getting_started/data_augmented_generation.ipynb index 51bcf663301..177ae8e9ce4 100644 --- a/docs/getting_started/data_augmented_generation.ipynb +++ b/docs/getting_started/data_augmented_generation.ipynb @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "91d307ed", "metadata": {}, "outputs": [], @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "10a93bf9", "metadata": {}, "outputs": [], @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "fa0f3066", "metadata": {}, "outputs": [], @@ -96,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "8465b4b7", "metadata": {}, "outputs": [], @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "611be801", "metadata": {}, "outputs": [ @@ -143,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "b6255b02", "metadata": {}, "outputs": [], @@ -153,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "ec4eacad", "metadata": {}, "outputs": [], @@ -163,17 +163,17 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "59c7508d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "\" The president said that Ketanji Brown Jackson is one of our nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. The president also said that Ketanji Brown Jackson is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" + "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds and will continue Justice Breyer's legacy of excellence.\"" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 123794a320e..24f90ea0bcd 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -88,14 +88,19 @@ class Chain(BaseModel, ABC): """ if not isinstance(inputs, dict): - if len(self.input_keys) != 1: + _input_keys = set(self.input_keys) + if self.memory is not None: + # If there are multiple input keys, but some get set by memory so that + # only one is not set, we can still figure out which key it is. + _input_keys = _input_keys.difference(self.memory.memory_variables) + if len(_input_keys) != 1: raise ValueError( f"A single string input was passed in, but this chain expects " - f"multiple inputs ({self.input_keys}). When a chain expects " + f"multiple inputs ({_input_keys}). When a chain expects " f"multiple inputs, please call it by passing in a dictionary, " "eg `chain({'foo': 1, 'bar': 2})`" ) - inputs = {self.input_keys[0]: inputs} + inputs = {list(_input_keys)[0]: inputs} if self.memory is not None: external_context = self.memory.load_memory_variables(inputs) inputs = dict(inputs, **external_context) diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index ade2d318eef..69dfd6bf7cc 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -1,10 +1,31 @@ """Test logic on base chain class.""" -from typing import Dict, List +from typing import Any, Dict, List import pytest from pydantic import BaseModel -from langchain.chains.base import Chain +from langchain.chains.base import Chain, Memory + + +class FakeMemory(Memory, BaseModel): + """Fake memory class for testing purposes.""" + + @property + def memory_variables(self) -> List[str]: + """Return baz variable.""" + return ["baz"] + + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """Return baz variable.""" + return {"baz": "foo"} + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + """Pass.""" + pass + + def clear(self) -> None: + """Pass.""" + pass class FakeChain(Chain, BaseModel): @@ -106,3 +127,9 @@ def test_multiple_output_keys_error() -> None: chain = FakeChain(the_output_keys=["foo", "bar"]) with pytest.raises(ValueError): chain.run("bar") + + +def test_run_arg_with_memory() -> None: + """Test run method works when arg is passed.""" + chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory()) + chain.run("bar")