add save/load methods for persona, more unit tests

Changes
- Add save_context to PersonaMemory
- Add load_memory_variables to PersonaMemory
- Expand unit tests for memory save/load
- Add full conversation simulation test
- Maintain full backward compatibility with other modules (hopefully)
This commit is contained in:
Mikolaj Mikuliszyn 2025-04-27 16:41:00 +01:00
parent 05a80bdf49
commit a982d2f27d
2 changed files with 258 additions and 45 deletions

View File

@ -4,15 +4,25 @@ import re
from langchain_core.memory import BaseMemory
class EnrichedMessage(BaseModel):
    """A message enriched with persona traits and metadata."""

    # Identifier for the message; save_context assigns the 1-based position
    # in recent_messages, so ids may repeat after trimming — see save_context.
    id: str
    # Raw text content of the message.
    content: str
    # Names of the persona traits detected in ``content``.
    traits: List[str] = Field(default_factory=list)
    # Free-form extra data (e.g. the per-trait hit counts under "traits_count").
    metadata: Dict[str, Any] = Field(default_factory=dict)
class PersonaMemory(BaseMemory):
"""Memory that tracks evolving agent persona traits over a conversation."""
"""
Memory module that dynamically tracks emotional and behavioral traits
exhibited by an agent over the course of a conversation.
Traits are automatically detected from output messages and
accumulated across interactions. Recent messages and detected traits
can be retrieved to enrich future prompts or modify agent behavior.
"""
memory_key: str = "persona"
input_key: str = "input"
@ -32,8 +42,9 @@ class PersonaMemory(BaseMemory):
self.recent_messages = []
def _detect_traits(self, text: str) -> Dict[str, int]:
"""Detect persona traits using both a default method and optionally an external engine.
"""
Detect persona traits using both a default method and optionally an external engine.
Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded)
"""
# Always run the fast, local simple detection first
@ -42,16 +53,17 @@ class PersonaMemory(BaseMemory):
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
"formal": ["Dear", "Sincerely", "Respectfully"],
"cautious": ["maybe", "perhaps", "I think", "it could be"],
"friendly": ["Hi", "Hey", "Hello", "Good to see you"],
"hesitant": ["maybe", "...", "I'm not sure", "perhaps", "unsure"],
"friendly": ["Hi", "Hey", "Hello", "Good to see you", "friendly"],
"curious": ["?", "I wonder", "Could you"],
"hesitant": ["...", "I'm not sure", "perhaps"],
"analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"],
}
trait_hits: Dict[str, int] = {}
lowered_text = text.lower()
# Clean the text for word matching by removing punctuation
clean_text = re.sub(r'[^\w\s]', ' ', lowered_text)
clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
words = clean_text.split()
for trait, patterns in trait_patterns.items():
@ -59,8 +71,7 @@ class PersonaMemory(BaseMemory):
for pattern in patterns:
pattern_lower = pattern.lower()
if trait == "friendly":
if (lowered_text.startswith(pattern_lower) or
pattern_lower in words):
if lowered_text.startswith(pattern_lower) or pattern_lower in words:
count += 1
elif trait == "enthusiastic":
if pattern == "!":
@ -68,7 +79,7 @@ class PersonaMemory(BaseMemory):
else:
if pattern_lower in lowered_text:
count += 1
if count > 0:
trait_hits[trait] = count
@ -83,3 +94,43 @@ class PersonaMemory(BaseMemory):
# Fall back to simple default detection if external fails or is unavailable
return trait_hits
def load_memory_variables(
    self, inputs: Dict[str, Any], include_messages: bool = False
) -> Dict[str, Any]:
    """Expose the persona state for prompt enrichment.

    Returns a mapping of ``self.memory_key`` to a dict containing a copy of
    the accumulated trait names and, when ``include_messages`` is True, the
    serialized recent messages as well. ``inputs`` is accepted for interface
    compatibility but not inspected.
    """
    snapshot: Dict[str, Any] = {"traits": list(self.persona_traits)}
    if include_messages:
        snapshot["recent_messages"] = [
            msg.model_dump() for msg in self.recent_messages
        ]
    return {self.memory_key: snapshot}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
    """Record one exchange: detect traits in the output and update memory.

    Only the agent's output text is analyzed for traits; the user input is
    accepted for interface compatibility but not inspected (the original
    read ``inputs[self.input_key]`` into an unused local — removed here).
    A missing output defaults to "" and produces an empty, trait-less message.
    """
    output_text = outputs.get(self.output_key, "")
    traits_detected = self._detect_traits(output_text)

    # NOTE(review): ids restart from len(recent_messages)+1, so once the
    # buffer is trimmed to k entries the same id is reissued on every save —
    # confirm whether ids need to be unique across the conversation.
    message = EnrichedMessage(
        id=str(len(self.recent_messages) + 1),
        content=output_text,
        traits=list(traits_detected),
        metadata={"traits_count": traits_detected},
    )
    self.recent_messages.append(message)

    # Keep only the k most recent messages.
    if len(self.recent_messages) > self.k:
        self.recent_messages = self.recent_messages[-self.k :]

    # Recompute the accumulated trait list from the surviving messages only,
    # so traits from trimmed messages are forgotten.
    self.persona_traits = list(
        {trait for msg in self.recent_messages for trait in msg.traits}
    )

View File

@ -7,31 +7,27 @@ from langchain.schema import AIMessage, HumanMessage, BaseMessage
import json
from typing import List, Optional
@contextmanager
def capture_output():
    """Temporarily redirect stdout/stderr into in-memory buffers.

    Yields the pair ``(stdout_buffer, stderr_buffer)``; the real streams are
    restored when the block exits, even if an exception is raised inside it.
    """
    out_buf = StringIO()
    err_buf = StringIO()
    saved_streams = (sys.stdout, sys.stderr)
    sys.stdout, sys.stderr = out_buf, err_buf
    try:
        yield out_buf, err_buf
    finally:
        sys.stdout, sys.stderr = saved_streams
class MockChatModel:
    """Stand-in for a chat model that simulates success and failure cases.

    Configure either a fixed ``response`` (success path) or
    ``should_fail=True`` with an ``error_msg`` (failure path).
    """

    def __init__(
        self,
        response: Optional[str] = None,
        should_fail: bool = False,
        error_msg: str = "",
    ):
        # Canned reply content returned on the success path.
        self.response = response
        # When True, invoke() raises instead of answering.
        self.should_fail = should_fail
        # Message carried by the simulated failure.
        self.error_msg = error_msg

    def invoke(self, messages: "List[BaseMessage]", **kwargs) -> "AIMessage":
        """Mimic a chat-model call: raise the configured error or reply."""
        if self.should_fail:
            raise ValueError(self.error_msg)
        return AIMessage(content=self.response)
class TestPersonaMemory(unittest.TestCase):
def setUp(self):
"""Set up test-level attributes."""
@ -40,13 +36,9 @@ class TestPersonaMemory(unittest.TestCase):
def _create_successful_chat_model(self):
    """Build a mock chat model whose reply is a valid trait-JSON payload."""
    traits_payload = {"traits": {"confident": 2, "analytical": 1, "professional": 1}}
    return MockChatModel(response=json.dumps(traits_payload))
def _create_failing_chat_model(self, error_type: str = "rate_limit"):
@ -57,14 +49,11 @@ class TestPersonaMemory(unittest.TestCase):
error_msg = "Invalid API key"
else:
error_msg = "Internal server error"
return MockChatModel(should_fail=True, error_msg=error_msg)
def test_enriched_message_creation(self):
input_data = {
"id": "msg-001",
"content": "Hello!"
}
input_data = {"id": "msg-001", "content": "Hello!"}
message = EnrichedMessage(**input_data)
self.assertEqual(message.id, "msg-001")
self.assertEqual(message.content, "Hello!")
@ -129,7 +118,7 @@ class TestPersonaMemory(unittest.TestCase):
memory = PersonaMemory()
output_texts = [
"I'm sorry, truly sorry about the confusion.",
"Hello there! Great job on that task!"
"Hello there! Great job on that task!",
]
for text in output_texts:
@ -153,12 +142,12 @@ class TestPersonaMemory(unittest.TestCase):
def test_successful_api_detection(self):
chat_model = self._create_successful_chat_model()
def trait_detector(text: str):
messages = [HumanMessage(content=text)]
response = chat_model.invoke(messages)
return json.loads(response.content)["traits"]
memory = PersonaMemory(trait_detection_engine=trait_detector)
text = "Based on the analysis, I can confidently say this is the best approach."
traits = memory._detect_traits(text)
@ -171,12 +160,12 @@ class TestPersonaMemory(unittest.TestCase):
def test_rate_limit_failure(self):
chat_model = self._create_failing_chat_model("rate_limit")
def trait_detector(text: str):
messages = [HumanMessage(content=text)]
response = chat_model.invoke(messages)
return json.loads(response.content)["traits"]
memory = PersonaMemory(trait_detection_engine=trait_detector)
text = "I'm so sorry about the confusion."
traits = memory._detect_traits(text)
@ -184,12 +173,12 @@ class TestPersonaMemory(unittest.TestCase):
def test_authentication_failure(self):
chat_model = self._create_failing_chat_model("authentication")
def trait_detector(text: str):
messages = [HumanMessage(content=text)]
response = chat_model.invoke(messages)
return json.loads(response.content)["traits"]
memory = PersonaMemory(trait_detection_engine=trait_detector)
text = "I apologize for the mistake."
traits = memory._detect_traits(text)
@ -197,16 +186,189 @@ class TestPersonaMemory(unittest.TestCase):
def test_server_error_failure(self):
    """A server-error from the external engine must fall back to local detection."""
    failing_model = self._create_failing_chat_model("server_error")

    def broken_detector(text: str):
        # Delegates to the mock model, which raises a simulated server error.
        reply = failing_model.invoke([HumanMessage(content=text)])
        return json.loads(reply.content)["traits"]

    memory = PersonaMemory(trait_detection_engine=broken_detector)
    detected = memory._detect_traits("I'm very sorry about that.")
    # The keyword-based fallback should still spot the apology.
    self.assertIn("apologetic", detected)
def test_save_context_updates_memory(self):
    """Saving one exchange records the message and its detected traits."""
    memory = PersonaMemory()
    reply = "I'm sorry for the confusion."
    memory.save_context({"input": "Hello!"}, {"output": reply})
    self.assertIn("apologetic", memory.persona_traits)
    self.assertEqual(len(memory.recent_messages), 1)
    stored = memory.recent_messages[0]
    self.assertEqual(stored.content, reply)
    self.assertIn("apologetic", stored.traits)
def test_load_memory_variables_returns_traits(self):
    """load_memory_variables exposes stored traits under the memory key."""
    memory = PersonaMemory()
    memory.persona_traits = ["friendly", "apologetic"]
    loaded = memory.load_memory_variables({})
    self.assertEqual(loaded["persona"], {"traits": ["friendly", "apologetic"]})
def test_save_context_updates_memory_with_multiple_traits(self):
    """All traits present in one output are recorded on the message and persona."""
    memory = PersonaMemory()
    reply = "I'm sorry for the confusion. I'm also friendly and analytical."
    memory.save_context({"input": "Hello!"}, {"output": reply})
    expected_traits = ("apologetic", "friendly", "analytical")
    for trait in expected_traits:
        self.assertIn(trait, memory.persona_traits)
    self.assertEqual(len(memory.recent_messages), 1)
    stored = memory.recent_messages[0]
    self.assertEqual(stored.content, reply)
    for trait in expected_traits:
        self.assertIn(trait, stored.traits)
def test_memory_trims_to_k_messages(self):
    """Only the k most recent messages survive; the oldest is dropped."""
    memory = PersonaMemory(k=2)
    exchanges = [
        ("Hi", "Sorry about that!"),
        ("Hello", "Apologies again!"),
        ("Hey", "Great job!"),
    ]
    for user_text, agent_text in exchanges:
        memory.save_context({"input": user_text}, {"output": agent_text})
    self.assertEqual(len(memory.recent_messages), 2)
    self.assertEqual(memory.recent_messages[0].content, "Apologies again!")
    self.assertEqual(memory.recent_messages[1].content, "Great job!")
def test_load_memory_variables_with_messages(self):
    """include_messages=True also returns the serialized recent messages."""
    memory = PersonaMemory()
    reply = "Sorry about the mistake!"
    memory.save_context({"input": "Hello"}, {"output": reply})
    persona = memory.load_memory_variables({}, include_messages=True)["persona"]
    self.assertIn("traits", persona)
    self.assertIn("recent_messages", persona)
    self.assertEqual(persona["recent_messages"][0]["content"], reply)
def test_save_context_with_missing_output(self):
    """A missing output is stored as an empty, trait-less message."""
    memory = PersonaMemory()
    memory.save_context({"input": "Hi"}, {})
    self.assertEqual(len(memory.recent_messages), 1)
    stored = memory.recent_messages[0]
    self.assertEqual(stored.content, "")  # falls back to the empty string
    self.assertEqual(stored.traits, [])  # nothing to detect in ""
def test_double_save_context_creates_two_entries(self):
    """Identical outputs saved twice still produce two separate entries."""
    memory = PersonaMemory()
    reply = "Sorry for that mistake."
    memory.save_context({"input": "Hi"}, {"output": reply})
    memory.save_context({"input": "Hey"}, {"output": reply})
    self.assertEqual(len(memory.recent_messages), 2)
    self.assertEqual(memory.recent_messages[0].content, reply)
    self.assertEqual(memory.recent_messages[1].content, reply)
def test_full_conversation_simulation_with_failures(self):
    """Simulate a longer conversation with mixed success and fallback during trait detection."""
    memory = PersonaMemory()
    # (user input, agent output) pairs; the expected trait per turn is noted inline.
    conversation = [
        ("Hi there!", "I'm so happy to meet you!"),  # Enthusiastic
        (
            "Could you help me?",
            "Of course! Happy to assist.",
        ),  # Friendly + Enthusiastic
        (
            "I'm unsure about this plan...",
            "Maybe we should rethink it.",
        ),  # Hesitant + Cautious (simulated failure here)
        (
            "Sorry about the delay.",
            "Apologies! It won't happen again.",
        ),  # Apologetic
    ]
    print("\n\n--- Starting full conversation simulation ---\n")

    # Simulate that at message 3, external engine fails (mock failure)
    def dynamic_trait_detector(text: str):
        # The third agent output contains "rethink", which triggers the
        # simulated outage; every other turn returns mocked trait dicts.
        if "rethink" in text:
            print("Simulated API Failure during trait detection.\n")
            raise ValueError("Simulated API failure during cautious response")
        print(f"Simulated API Success - analyzing text: '{text}'\n")
        if "happy to meet you" in text:
            return {"enthusiastic": 1, "friendly": 1, "mocked_positive": 1}
        elif "Happy to assist" in text:
            return {"friendly": 1, "enthusiastic": 1, "mocked_positive": 1}
        elif "Apologies" in text:
            return {"apologetic": 1, "mocked_positive": 1}
        return {"mocked_positive": 1}

    # Install the detector after construction; presumably _detect_traits
    # consults this attribute and falls back on exceptions — the fallback
    # itself is not visible in this view. TODO confirm against PersonaMemory.
    memory.trait_detection_engine = dynamic_trait_detector
    for idx, (user_input, agent_output) in enumerate(conversation):
        print(f"--- Message {idx+1} ---")
        print(f"Input: {user_input}")
        print(f"Output: {agent_output}\n")
        memory.save_context({"input": user_input}, {"output": agent_output})
        last_message = memory.recent_messages[-1]
        print(f"Saved Message Traits: {last_message.traits}\n")
    print("--- Final accumulated traits ---")
    print(memory.persona_traits)
    print("\n--- End of conversation simulation ---\n")
    # Assertions:
    # There should be 4 recent messages (conversation length)
    self.assertEqual(len(memory.recent_messages), 4)
    # Traits should have accumulated normally, fallback triggered once
    detected_traits = set(memory.persona_traits)
    expected_traits = {
        "enthusiastic",
        "friendly",
        "hesitant",
        "cautious",
        "apologetic",
        "mocked_positive",
    }
    print(f"Detected traits: {detected_traits}")
    print(f"Expected traits: {expected_traits}")
    for trait in expected_traits:
        self.assertIn(trait, detected_traits)
    # Messages should have traits assigned
    for message in memory.recent_messages:
        self.assertIsInstance(message.traits, list)
        self.assertGreaterEqual(len(message.traits), 1)
if __name__ == "__main__":
    # The source showed unittest.main() twice; the duplicate was dead code,
    # since unittest.main() raises SystemExit after the run. Call it once.
    unittest.main()
"""
Just the convo simulation:
python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona.TestPersonaMemory.test_full_conversation_simulation_with_failures
All tests:
python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona
or quietly:
python3 -m unittest -q libs.langchain.tests.unit_tests.memory.test_persona
"""