From b913df3774aa605fd004b0473b438f34faa8816d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 24 Nov 2022 20:11:29 -0800 Subject: [PATCH] make attrs public (#187) since they are used outside of the class, should be public --- langchain/chains/base.py | 8 ++++---- langchain/chains/conversation/memory.py | 8 ++++---- tests/unit_tests/chains/test_conversation.py | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index e69855bd41b..cdfae440103 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -20,11 +20,11 @@ class Memory(BaseModel, ABC): """Input keys this memory class will load dynamically.""" @abstractmethod - def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]: """Return key-value pairs given the text input to the chain.""" @abstractmethod - def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save the context of this model run to memory.""" @@ -77,7 +77,7 @@ class Chain(BaseModel, ABC): """ if self.memory is not None: - external_context = self.memory._load_dynamic_keys(inputs) + external_context = self.memory.load_dynamic_keys(inputs) inputs = dict(inputs, **external_context) self._validate_inputs(inputs) if self.verbose: @@ -87,7 +87,7 @@ class Chain(BaseModel, ABC): print("\n\033[1m> Finished chain.\033[0m") self._validate_outputs(outputs) if self.memory is not None: - self.memory._save_context(inputs, outputs) + self.memory.save_context(inputs, outputs) if return_only_outputs: return outputs else: diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py index 463a49f4b20..7e822921731 100644 --- a/langchain/chains/conversation/memory.py +++ b/langchain/chains/conversation/memory.py @@ -24,11 +24,11 @@ class ConversationBufferMemory(Memory, BaseModel): """ return [self.dynamic_key] - def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]: """Return history buffer.""" return {self.dynamic_key: self.buffer} - def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer.""" prompt_input_keys = list(set(inputs).difference(self.dynamic_keys)) if len(prompt_input_keys) != 1: @@ -56,7 +56,7 @@ class ConversationSummaryMemory(Memory, BaseModel): """ return [self.dynamic_key] - def _load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def load_dynamic_keys(self, inputs: Dict[str, Any]) -> Dict[str, str]: """Return history buffer.""" return {self.dynamic_key: self.buffer} @@ -72,7 +72,7 @@ class ConversationSummaryMemory(Memory, BaseModel): ) return values - def _save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer.""" prompt_input_keys = list(set(inputs).difference(self.dynamic_keys)) if len(prompt_input_keys) != 1: diff --git a/tests/unit_tests/chains/test_conversation.py b/tests/unit_tests/chains/test_conversation.py index dfb706e1fa1..4e5cc244030 100644 --- a/tests/unit_tests/chains/test_conversation.py +++ b/tests/unit_tests/chains/test_conversation.py @@ -50,19 +50,19 @@ def test_conversation_memory(memory: Memory) -> None: good_inputs = {"foo": "bar", "baz": "foo"} # This is a good output because these is one variable. good_outputs = {"bar": "foo"} - memory._save_context(good_inputs, good_outputs) + memory.save_context(good_inputs, good_outputs) # This is a bad input because there are two variables that aren't the same as baz. bad_inputs = {"foo": "bar", "foo1": "bar"} with pytest.raises(ValueError): - memory._save_context(bad_inputs, good_outputs) + memory.save_context(bad_inputs, good_outputs) # This is a bad input because the only variable is the same as baz. bad_inputs = {"baz": "bar"} with pytest.raises(ValueError): - memory._save_context(bad_inputs, good_outputs) + memory.save_context(bad_inputs, good_outputs) # This is a bad output because it is empty. with pytest.raises(ValueError): - memory._save_context(good_inputs, {}) + memory.save_context(good_inputs, {}) # This is a bad output because there are two keys. bad_outputs = {"foo": "bar", "foo1": "bar"} with pytest.raises(ValueError): - memory._save_context(good_inputs, bad_outputs) + memory.save_context(good_inputs, bad_outputs)