mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
check memory variables (#411)
can have multiple input keys, if some come from memory
This commit is contained in:
parent
f990395211
commit
20959d8c36
@ -963,7 +963,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -37,7 +37,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"id": "91d307ed",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -56,7 +56,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"id": "10a93bf9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -77,7 +77,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"id": "fa0f3066",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -96,7 +96,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "8465b4b7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -107,7 +107,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 6,
|
||||
"id": "611be801",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -143,7 +143,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 7,
|
||||
"id": "b6255b02",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -153,7 +153,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 8,
|
||||
"id": "ec4eacad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -163,17 +163,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 9,
|
||||
"id": "59c7508d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" The president said that Ketanji Brown Jackson is one of our nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. The president also said that Ketanji Brown Jackson is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
|
||||
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds and will continue Justice Breyer's legacy of excellence.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -88,14 +88,19 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
if len(self.input_keys) != 1:
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({self.input_keys}). When a chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {self.input_keys[0]: inputs}
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
|
@ -1,10 +1,31 @@
|
||||
"""Test logic on base chain class."""
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.base import Chain, Memory
|
||||
|
||||
|
||||
class FakeMemory(Memory, BaseModel):
|
||||
"""Fake memory class for testing purposes."""
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Return baz variable."""
|
||||
return ["baz"]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return baz variable."""
|
||||
return {"baz": "foo"}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
|
||||
class FakeChain(Chain, BaseModel):
|
||||
@ -106,3 +127,9 @@ def test_multiple_output_keys_error() -> None:
|
||||
chain = FakeChain(the_output_keys=["foo", "bar"])
|
||||
with pytest.raises(ValueError):
|
||||
chain.run("bar")
|
||||
|
||||
|
||||
def test_run_arg_with_memory() -> None:
|
||||
"""Test run method works when arg is passed."""
|
||||
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
|
||||
chain.run("bar")
|
||||
|
Loading…
Reference in New Issue
Block a user