From a982d2f27dc2c8804aaf5a88086d4abcc681c9ed Mon Sep 17 00:00:00 2001 From: Mikolaj Mikuliszyn Date: Sun, 27 Apr 2025 16:41:00 +0100 Subject: [PATCH] add save/load methods for persona, more unit tests Changes - Add save_context to PersonaMemory - Add load_memory_variables to PersonaMemory - Expand unit tests for memory save/load - Add full conversation simulation test - Maintain full backward compatibility with other modules (hopefully) --- libs/langchain/langchain/memory/persona.py | 71 +++++- .../tests/unit_tests/memory/test_persona.py | 232 +++++++++++++++--- 2 files changed, 258 insertions(+), 45 deletions(-) diff --git a/libs/langchain/langchain/memory/persona.py b/libs/langchain/langchain/memory/persona.py index 3a601c84fea..bfc4fb613af 100644 --- a/libs/langchain/langchain/memory/persona.py +++ b/libs/langchain/langchain/memory/persona.py @@ -4,15 +4,25 @@ import re from langchain_core.memory import BaseMemory + class EnrichedMessage(BaseModel): """A message enriched with persona traits and metadata.""" + id: str content: str traits: List[str] = Field(default_factory=list) metadata: Dict[str, Any] = Field(default_factory=dict) + class PersonaMemory(BaseMemory): - """Memory that tracks evolving agent persona traits over a conversation.""" + """ + Memory module that dynamically tracks emotional and behavioral traits + exhibited by an agent over the course of a conversation. + + Traits are automatically detected from output messages and + accumulated across interactions. Recent messages and detected traits + can be retrieved to enrich future prompts or modify agent behavior. + """ memory_key: str = "persona" input_key: str = "input" @@ -32,8 +42,9 @@ class PersonaMemory(BaseMemory): self.recent_messages = [] def _detect_traits(self, text: str) -> Dict[str, int]: - """Detect persona traits using both a default method and optionally an external engine. - + """ + Detect persona traits using both a default method and optionally an external engine. + Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded) """ # Always run the fast, local simple detection first @@ -42,16 +53,17 @@ class PersonaMemory(BaseMemory): "enthusiastic": ["!", "awesome", "great job", "fantastic"], "formal": ["Dear", "Sincerely", "Respectfully"], "cautious": ["maybe", "perhaps", "I think", "it could be"], - "friendly": ["Hi", "Hey", "Hello", "Good to see you"], + "hesitant": ["maybe", "...", "I'm not sure", "perhaps", "unsure"], + "friendly": ["Hi", "Hey", "Hello", "Good to see you", "friendly"], "curious": ["?", "I wonder", "Could you"], - "hesitant": ["...", "I'm not sure", "perhaps"], + "analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"], } trait_hits: Dict[str, int] = {} lowered_text = text.lower() - + # Clean the text for word matching by removing punctuation - clean_text = re.sub(r'[^\w\s]', ' ', lowered_text) + clean_text = re.sub(r"[^\w\s]", " ", lowered_text) words = clean_text.split() for trait, patterns in trait_patterns.items(): @@ -59,8 +71,7 @@ class PersonaMemory(BaseMemory): for pattern in patterns: pattern_lower = pattern.lower() if trait == "friendly": - if (lowered_text.startswith(pattern_lower) or - pattern_lower in words): + if lowered_text.startswith(pattern_lower) or pattern_lower in words: count += 1 elif trait == "enthusiastic": if pattern == "!": @@ -68,7 +79,7 @@ class PersonaMemory(BaseMemory): else: if pattern_lower in lowered_text: count += 1 - + if count > 0: trait_hits[trait] = count @@ -83,3 +94,43 @@ class PersonaMemory(BaseMemory): # Fall back to simple default detection if external fails or is unavailable return trait_hits + + def load_memory_variables( + self, inputs: Dict[str, Any], include_messages: bool = False + ) -> Dict[str, Any]: + """Return the stored persona traits and optionally recent conversation messages.""" + + memory_data = {"traits": self.persona_traits.copy()} + + if include_messages: + memory_data["recent_messages"] = [ + message.model_dump() for message in self.recent_messages + ] + + return {self.memory_key: memory_data} + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + """Analyze outputs and update the persona traits and recent messages.""" + + input_text = inputs.get(self.input_key, "") + output_text = outputs.get(self.output_key, "") + + traits_detected = self._detect_traits(output_text) + + message = EnrichedMessage( + id=str(len(self.recent_messages) + 1), + content=output_text, + traits=list(traits_detected.keys()), + metadata={"traits_count": traits_detected}, + ) + + self.recent_messages.append(message) + + # Trim messages to maintain memory size k + if len(self.recent_messages) > self.k: + self.recent_messages = self.recent_messages[-self.k :] + + # Update persona traits + self.persona_traits = list( + {trait for msg in self.recent_messages for trait in msg.traits} + ) diff --git a/libs/langchain/tests/unit_tests/memory/test_persona.py b/libs/langchain/tests/unit_tests/memory/test_persona.py index 01f2a5ad42c..c06703ffd11 100644 --- a/libs/langchain/tests/unit_tests/memory/test_persona.py +++ b/libs/langchain/tests/unit_tests/memory/test_persona.py @@ -7,31 +7,27 @@ from langchain.schema import AIMessage, HumanMessage, BaseMessage import json from typing import List, Optional -@contextmanager -def capture_output(): - """Capture stdout and stderr.""" - new_out, new_err = StringIO(), StringIO() - old_out, old_err = sys.stdout, sys.stderr - try: - sys.stdout, sys.stderr = new_out, new_err - yield sys.stdout, sys.stderr - finally: - sys.stdout, sys.stderr = old_out, old_err class MockChatModel: """A simple mock chat model that can simulate both success and failure cases.""" - - def __init__(self, response: Optional[str] = None, should_fail: bool = False, error_msg: str = ""): + + def __init__( + self, + response: Optional[str] = None, + should_fail: bool = False, + error_msg: str = "", + ): self.response = response self.should_fail = should_fail self.error_msg = error_msg - + def invoke(self, messages: List[BaseMessage], **kwargs) -> AIMessage: """Simulate chat model invocation.""" if self.should_fail: raise ValueError(self.error_msg) return AIMessage(content=self.response) + class TestPersonaMemory(unittest.TestCase): def setUp(self): """Set up test-level attributes.""" @@ -40,13 +36,9 @@ class TestPersonaMemory(unittest.TestCase): def _create_successful_chat_model(self): """Create a mock chat model that returns successful responses.""" - response = json.dumps({ - "traits": { - "confident": 2, - "analytical": 1, - "professional": 1 - } - }) + response = json.dumps( + {"traits": {"confident": 2, "analytical": 1, "professional": 1}} + ) return MockChatModel(response=response) def _create_failing_chat_model(self, error_type: str = "rate_limit"): @@ -57,14 +49,11 @@ class TestPersonaMemory(unittest.TestCase): error_msg = "Invalid API key" else: error_msg = "Internal server error" - + return MockChatModel(should_fail=True, error_msg=error_msg) def test_enriched_message_creation(self): - input_data = { - "id": "msg-001", - "content": "Hello!" - } + input_data = {"id": "msg-001", "content": "Hello!"} message = EnrichedMessage(**input_data) self.assertEqual(message.id, "msg-001") self.assertEqual(message.content, "Hello!") @@ -129,7 +118,7 @@ class TestPersonaMemory(unittest.TestCase): memory = PersonaMemory() output_texts = [ "I'm sorry, truly sorry about the confusion.", - "Hello there! Great job on that task!" + "Hello there! Great job on that task!", ] for text in output_texts: @@ -153,12 +142,12 @@ class TestPersonaMemory(unittest.TestCase): def test_successful_api_detection(self): chat_model = self._create_successful_chat_model() - + def trait_detector(text: str): messages = [HumanMessage(content=text)] response = chat_model.invoke(messages) return json.loads(response.content)["traits"] - + memory = PersonaMemory(trait_detection_engine=trait_detector) text = "Based on the analysis, I can confidently say this is the best approach." traits = memory._detect_traits(text) @@ -171,12 +160,12 @@ class TestPersonaMemory(unittest.TestCase): def test_rate_limit_failure(self): chat_model = self._create_failing_chat_model("rate_limit") - + def trait_detector(text: str): messages = [HumanMessage(content=text)] response = chat_model.invoke(messages) return json.loads(response.content)["traits"] - + memory = PersonaMemory(trait_detection_engine=trait_detector) text = "I'm so sorry about the confusion." traits = memory._detect_traits(text) @@ -184,12 +173,12 @@ class TestPersonaMemory(unittest.TestCase): def test_authentication_failure(self): chat_model = self._create_failing_chat_model("authentication") - + def trait_detector(text: str): messages = [HumanMessage(content=text)] response = chat_model.invoke(messages) return json.loads(response.content)["traits"] - + memory = PersonaMemory(trait_detection_engine=trait_detector) text = "I apologize for the mistake." traits = memory._detect_traits(text) @@ -197,16 +186,189 @@ class TestPersonaMemory(unittest.TestCase): def test_server_error_failure(self): chat_model = self._create_failing_chat_model("server_error") - + def trait_detector(text: str): messages = [HumanMessage(content=text)] response = chat_model.invoke(messages) return json.loads(response.content)["traits"] - + memory = PersonaMemory(trait_detection_engine=trait_detector) text = "I'm very sorry about that." traits = memory._detect_traits(text) self.assertIn("apologetic", traits) # Should fall back to default detection + def test_save_context_updates_memory(self): + memory = PersonaMemory() + input_text = "Hello!" + output_text = "I'm sorry for the confusion." + + memory.save_context({"input": input_text}, {"output": output_text}) + self.assertIn("apologetic", memory.persona_traits) + self.assertEqual(len(memory.recent_messages), 1) + self.assertEqual(memory.recent_messages[0].content, output_text) + self.assertIn("apologetic", memory.recent_messages[0].traits) + + def test_load_memory_variables_returns_traits(self): + memory = PersonaMemory() + memory.persona_traits = ["friendly", "apologetic"] + + result = memory.load_memory_variables({}) + self.assertEqual(result["persona"], {"traits": ["friendly", "apologetic"]}) + + def test_save_context_updates_memory_with_multiple_traits(self): + memory = PersonaMemory() + input_text = "Hello!" + output_text = "I'm sorry for the confusion. I'm also friendly and analytical." + + memory.save_context({"input": input_text}, {"output": output_text}) + self.assertIn("apologetic", memory.persona_traits) + self.assertIn("friendly", memory.persona_traits) + self.assertIn("analytical", memory.persona_traits) + self.assertEqual(len(memory.recent_messages), 1) + self.assertEqual(memory.recent_messages[0].content, output_text) + self.assertIn("apologetic", memory.recent_messages[0].traits) + self.assertIn("friendly", memory.recent_messages[0].traits) + self.assertIn("analytical", memory.recent_messages[0].traits) + + def test_memory_trims_to_k_messages(self): + memory = PersonaMemory(k=2) + + memory.save_context({"input": "Hi"}, {"output": "Sorry about that!"}) + memory.save_context({"input": "Hello"}, {"output": "Apologies again!"}) + memory.save_context({"input": "Hey"}, {"output": "Great job!"}) # Third message + + self.assertEqual(len(memory.recent_messages), 2) + self.assertEqual(memory.recent_messages[0].content, "Apologies again!") + self.assertEqual(memory.recent_messages[1].content, "Great job!") + + def test_load_memory_variables_with_messages(self): + memory = PersonaMemory() + output_text = "Sorry about the mistake!" + + memory.save_context({"input": "Hello"}, {"output": output_text}) + + result = memory.load_memory_variables({}, include_messages=True) + self.assertIn("traits", result["persona"]) + self.assertIn("recent_messages", result["persona"]) + self.assertEqual( + result["persona"]["recent_messages"][0]["content"], output_text + ) + + def test_save_context_with_missing_output(self): + memory = PersonaMemory() + + memory.save_context({"input": "Hi"}, {}) # No output provided + + self.assertEqual(len(memory.recent_messages), 1) + self.assertEqual( + memory.recent_messages[0].content, "" + ) # Should be empty string + self.assertEqual(memory.recent_messages[0].traits, []) # No traits + + def test_double_save_context_creates_two_entries(self): + memory = PersonaMemory() + + output_text = "Sorry for that mistake." + memory.save_context({"input": "Hi"}, {"output": output_text}) + memory.save_context( + {"input": "Hey"}, {"output": output_text} + ) # Same output again + + self.assertEqual(len(memory.recent_messages), 2) + self.assertEqual(memory.recent_messages[0].content, output_text) + self.assertEqual(memory.recent_messages[1].content, output_text) + + def test_full_conversation_simulation_with_failures(self): + """Simulate a longer conversation with mixed success and fallback during trait detection.""" + memory = PersonaMemory() + + conversation = [ + ("Hi there!", "I'm so happy to meet you!"), # Enthusiastic + ( + "Could you help me?", + "Of course! Happy to assist.", + ), # Friendly + Enthusiastic + ( + "I'm unsure about this plan...", + "Maybe we should rethink it.", + ), # Hesitant + Cautious (simulated failure here) + ( + "Sorry about the delay.", + "Apologies! It won't happen again.", + ), # Apologetic + ] + + print("\n\n--- Starting full conversation simulation ---\n") + + # Simulate that at message 3, external engine fails (mock failure) + def dynamic_trait_detector(text: str): + if "rethink" in text: + print("Simulated API Failure during trait detection.\n") + raise ValueError("Simulated API failure during cautious response") + print(f"Simulated API Success - analyzing text: '{text}'\n") + if "happy to meet you" in text: + return {"enthusiastic": 1, "friendly": 1, "mocked_positive": 1} + elif "Happy to assist" in text: + return {"friendly": 1, "enthusiastic": 1, "mocked_positive": 1} + elif "Apologies" in text: + return {"apologetic": 1, "mocked_positive": 1} + return {"mocked_positive": 1} + + memory.trait_detection_engine = dynamic_trait_detector + + for idx, (user_input, agent_output) in enumerate(conversation): + print(f"--- Message {idx+1} ---") + print(f"Input: {user_input}") + print(f"Output: {agent_output}\n") + memory.save_context({"input": user_input}, {"output": agent_output}) + + last_message = memory.recent_messages[-1] + print(f"Saved Message Traits: {last_message.traits}\n") + + print("--- Final accumulated traits ---") + print(memory.persona_traits) + print("\n--- End of conversation simulation ---\n") + + # Assertions: + # There should be 4 recent messages (conversation length) + self.assertEqual(len(memory.recent_messages), 4) + + # Traits should have accumulated normally, fallback triggered once + detected_traits = set(memory.persona_traits) + + expected_traits = { + "enthusiastic", + "friendly", + "hesitant", + "cautious", + "apologetic", + "mocked_positive", + } + + print(f"Detected traits: {detected_traits}") + print(f"Expected traits: {expected_traits}") + + for trait in expected_traits: + self.assertIn(trait, detected_traits) + + # Messages should have traits assigned + for message in memory.recent_messages: + self.assertIsInstance(message.traits, list) + self.assertGreaterEqual(len(message.traits), 1) + + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() + + +""" +Just the convo simulation: +python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona.TestPersonaMemory.test_full_conversation_simulation_with_failure + +All tests: +python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona + +or quitetly: +python3 -m unittest -q libs.langchain.tests.unit_tests.memory.test_persona + +"""