mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
make attrs public (#187)
since they are used outside of the class, should be public
This commit is contained in:
parent
ae9c6257fe
commit
b913df3774
@ -20,11 +20,11 @@ class Memory(BaseModel, ABC):
|
||||
"""Input keys this memory class will load dynamically."""
|
||||
|
||||
@abstractmethod
|
||||
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
|
||||
@abstractmethod
|
||||
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save the context of this model run to memory."""
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
"""
|
||||
if self.memory is not None:
|
||||
external_context = self.memory._load_dynamic_keys(inputs)
|
||||
external_context = self.memory.load_dynamic_keys(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
if self.verbose:
|
||||
@ -87,7 +87,7 @@ class Chain(BaseModel, ABC):
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory._save_context(inputs, outputs)
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
|
@ -24,11 +24,11 @@ class ConversationBufferMemory(Memory, BaseModel):
|
||||
"""
|
||||
return [self.dynamic_key]
|
||||
|
||||
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.dynamic_key: self.buffer}
|
||||
|
||||
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
|
||||
if len(prompt_input_keys) != 1:
|
||||
@ -56,7 +56,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
|
||||
"""
|
||||
return [self.dynamic_key]
|
||||
|
||||
def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.dynamic_key: self.buffer}
|
||||
|
||||
@ -72,7 +72,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
prompt_input_keys = list(set(inputs).difference(self.dynamic_keys))
|
||||
if len(prompt_input_keys) != 1:
|
||||
|
@ -50,19 +50,19 @@ def test_conversation_memory(memory: Memory) -> None:
|
||||
good_inputs = {"foo": "bar", "baz": "foo"}
|
||||
# This is a good output because these is one variable.
|
||||
good_outputs = {"bar": "foo"}
|
||||
memory._save_context(good_inputs, good_outputs)
|
||||
memory.save_context(good_inputs, good_outputs)
|
||||
# This is a bad input because there are two variables that aren't the same as baz.
|
||||
bad_inputs = {"foo": "bar", "foo1": "bar"}
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(bad_inputs, good_outputs)
|
||||
memory.save_context(bad_inputs, good_outputs)
|
||||
# This is a bad input because the only variable is the same as baz.
|
||||
bad_inputs = {"baz": "bar"}
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(bad_inputs, good_outputs)
|
||||
memory.save_context(bad_inputs, good_outputs)
|
||||
# This is a bad output because it is empty.
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(good_inputs, {})
|
||||
memory.save_context(good_inputs, {})
|
||||
# This is a bad output because there are two keys.
|
||||
bad_outputs = {"foo": "bar", "foo1": "bar"}
|
||||
with pytest.raises(ValueError):
|
||||
memory._save_context(good_inputs, bad_outputs)
|
||||
memory.save_context(good_inputs, bad_outputs)
|
||||
|
Loading…
Reference in New Issue
Block a user