mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
add save/load methods for persona, more unit tests
Changes - Add save_context to PersonaMemory - Add load_memory_variables to PersonaMemory - Expand unit tests for memory save/load - Add full conversation simulation test - Maintain full backward compatibility with other modules (hopefully)
This commit is contained in:
parent
05a80bdf49
commit
a982d2f27d
@ -4,15 +4,25 @@ import re
|
||||
|
||||
from langchain_core.memory import BaseMemory
|
||||
|
||||
|
||||
class EnrichedMessage(BaseModel):
|
||||
"""A message enriched with persona traits and metadata."""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
traits: List[str] = Field(default_factory=list)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class PersonaMemory(BaseMemory):
|
||||
"""Memory that tracks evolving agent persona traits over a conversation."""
|
||||
"""
|
||||
Memory module that dynamically tracks emotional and behavioral traits
|
||||
exhibited by an agent over the course of a conversation.
|
||||
|
||||
Traits are automatically detected from output messages and
|
||||
accumulated across interactions. Recent messages and detected traits
|
||||
can be retrieved to enrich future prompts or modify agent behavior.
|
||||
"""
|
||||
|
||||
memory_key: str = "persona"
|
||||
input_key: str = "input"
|
||||
@ -32,7 +42,8 @@ class PersonaMemory(BaseMemory):
|
||||
self.recent_messages = []
|
||||
|
||||
def _detect_traits(self, text: str) -> Dict[str, int]:
|
||||
"""Detect persona traits using both a default method and optionally an external engine.
|
||||
"""
|
||||
Detect persona traits using both a default method and optionally an external engine.
|
||||
|
||||
Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded)
|
||||
"""
|
||||
@ -42,16 +53,17 @@ class PersonaMemory(BaseMemory):
|
||||
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
|
||||
"formal": ["Dear", "Sincerely", "Respectfully"],
|
||||
"cautious": ["maybe", "perhaps", "I think", "it could be"],
|
||||
"friendly": ["Hi", "Hey", "Hello", "Good to see you"],
|
||||
"hesitant": ["maybe", "...", "I'm not sure", "perhaps", "unsure"],
|
||||
"friendly": ["Hi", "Hey", "Hello", "Good to see you", "friendly"],
|
||||
"curious": ["?", "I wonder", "Could you"],
|
||||
"hesitant": ["...", "I'm not sure", "perhaps"],
|
||||
"analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"],
|
||||
}
|
||||
|
||||
trait_hits: Dict[str, int] = {}
|
||||
lowered_text = text.lower()
|
||||
|
||||
# Clean the text for word matching by removing punctuation
|
||||
clean_text = re.sub(r'[^\w\s]', ' ', lowered_text)
|
||||
clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
|
||||
words = clean_text.split()
|
||||
|
||||
for trait, patterns in trait_patterns.items():
|
||||
@ -59,8 +71,7 @@ class PersonaMemory(BaseMemory):
|
||||
for pattern in patterns:
|
||||
pattern_lower = pattern.lower()
|
||||
if trait == "friendly":
|
||||
if (lowered_text.startswith(pattern_lower) or
|
||||
pattern_lower in words):
|
||||
if lowered_text.startswith(pattern_lower) or pattern_lower in words:
|
||||
count += 1
|
||||
elif trait == "enthusiastic":
|
||||
if pattern == "!":
|
||||
@ -83,3 +94,43 @@ class PersonaMemory(BaseMemory):
|
||||
|
||||
# Fall back to simple default detection if external fails or is unavailable
|
||||
return trait_hits
|
||||
|
||||
def load_memory_variables(
|
||||
self, inputs: Dict[str, Any], include_messages: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Return the stored persona traits and optionally recent conversation messages."""
|
||||
|
||||
memory_data = {"traits": self.persona_traits.copy()}
|
||||
|
||||
if include_messages:
|
||||
memory_data["recent_messages"] = [
|
||||
message.model_dump() for message in self.recent_messages
|
||||
]
|
||||
|
||||
return {self.memory_key: memory_data}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Analyze outputs and update the persona traits and recent messages."""
|
||||
|
||||
input_text = inputs.get(self.input_key, "")
|
||||
output_text = outputs.get(self.output_key, "")
|
||||
|
||||
traits_detected = self._detect_traits(output_text)
|
||||
|
||||
message = EnrichedMessage(
|
||||
id=str(len(self.recent_messages) + 1),
|
||||
content=output_text,
|
||||
traits=list(traits_detected.keys()),
|
||||
metadata={"traits_count": traits_detected},
|
||||
)
|
||||
|
||||
self.recent_messages.append(message)
|
||||
|
||||
# Trim messages to maintain memory size k
|
||||
if len(self.recent_messages) > self.k:
|
||||
self.recent_messages = self.recent_messages[-self.k :]
|
||||
|
||||
# Update persona traits
|
||||
self.persona_traits = list(
|
||||
{trait for msg in self.recent_messages for trait in msg.traits}
|
||||
)
|
||||
|
@ -7,21 +7,16 @@ from langchain.schema import AIMessage, HumanMessage, BaseMessage
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
@contextmanager
|
||||
def capture_output():
|
||||
"""Capture stdout and stderr."""
|
||||
new_out, new_err = StringIO(), StringIO()
|
||||
old_out, old_err = sys.stdout, sys.stderr
|
||||
try:
|
||||
sys.stdout, sys.stderr = new_out, new_err
|
||||
yield sys.stdout, sys.stderr
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_out, old_err
|
||||
|
||||
class MockChatModel:
|
||||
"""A simple mock chat model that can simulate both success and failure cases."""
|
||||
|
||||
def __init__(self, response: Optional[str] = None, should_fail: bool = False, error_msg: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
response: Optional[str] = None,
|
||||
should_fail: bool = False,
|
||||
error_msg: str = "",
|
||||
):
|
||||
self.response = response
|
||||
self.should_fail = should_fail
|
||||
self.error_msg = error_msg
|
||||
@ -32,6 +27,7 @@ class MockChatModel:
|
||||
raise ValueError(self.error_msg)
|
||||
return AIMessage(content=self.response)
|
||||
|
||||
|
||||
class TestPersonaMemory(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test-level attributes."""
|
||||
@ -40,13 +36,9 @@ class TestPersonaMemory(unittest.TestCase):
|
||||
|
||||
def _create_successful_chat_model(self):
|
||||
"""Create a mock chat model that returns successful responses."""
|
||||
response = json.dumps({
|
||||
"traits": {
|
||||
"confident": 2,
|
||||
"analytical": 1,
|
||||
"professional": 1
|
||||
}
|
||||
})
|
||||
response = json.dumps(
|
||||
{"traits": {"confident": 2, "analytical": 1, "professional": 1}}
|
||||
)
|
||||
return MockChatModel(response=response)
|
||||
|
||||
def _create_failing_chat_model(self, error_type: str = "rate_limit"):
|
||||
@ -61,10 +53,7 @@ class TestPersonaMemory(unittest.TestCase):
|
||||
return MockChatModel(should_fail=True, error_msg=error_msg)
|
||||
|
||||
def test_enriched_message_creation(self):
|
||||
input_data = {
|
||||
"id": "msg-001",
|
||||
"content": "Hello!"
|
||||
}
|
||||
input_data = {"id": "msg-001", "content": "Hello!"}
|
||||
message = EnrichedMessage(**input_data)
|
||||
self.assertEqual(message.id, "msg-001")
|
||||
self.assertEqual(message.content, "Hello!")
|
||||
@ -129,7 +118,7 @@ class TestPersonaMemory(unittest.TestCase):
|
||||
memory = PersonaMemory()
|
||||
output_texts = [
|
||||
"I'm sorry, truly sorry about the confusion.",
|
||||
"Hello there! Great job on that task!"
|
||||
"Hello there! Great job on that task!",
|
||||
]
|
||||
|
||||
for text in output_texts:
|
||||
@ -208,5 +197,178 @@ class TestPersonaMemory(unittest.TestCase):
|
||||
traits = memory._detect_traits(text)
|
||||
self.assertIn("apologetic", traits) # Should fall back to default detection
|
||||
|
||||
def test_save_context_updates_memory(self):
|
||||
memory = PersonaMemory()
|
||||
input_text = "Hello!"
|
||||
output_text = "I'm sorry for the confusion."
|
||||
|
||||
memory.save_context({"input": input_text}, {"output": output_text})
|
||||
self.assertIn("apologetic", memory.persona_traits)
|
||||
self.assertEqual(len(memory.recent_messages), 1)
|
||||
self.assertEqual(memory.recent_messages[0].content, output_text)
|
||||
self.assertIn("apologetic", memory.recent_messages[0].traits)
|
||||
|
||||
def test_load_memory_variables_returns_traits(self):
|
||||
memory = PersonaMemory()
|
||||
memory.persona_traits = ["friendly", "apologetic"]
|
||||
|
||||
result = memory.load_memory_variables({})
|
||||
self.assertEqual(result["persona"], {"traits": ["friendly", "apologetic"]})
|
||||
|
||||
def test_save_context_updates_memory_with_multiple_traits(self):
|
||||
memory = PersonaMemory()
|
||||
input_text = "Hello!"
|
||||
output_text = "I'm sorry for the confusion. I'm also friendly and analytical."
|
||||
|
||||
memory.save_context({"input": input_text}, {"output": output_text})
|
||||
self.assertIn("apologetic", memory.persona_traits)
|
||||
self.assertIn("friendly", memory.persona_traits)
|
||||
self.assertIn("analytical", memory.persona_traits)
|
||||
self.assertEqual(len(memory.recent_messages), 1)
|
||||
self.assertEqual(memory.recent_messages[0].content, output_text)
|
||||
self.assertIn("apologetic", memory.recent_messages[0].traits)
|
||||
self.assertIn("friendly", memory.recent_messages[0].traits)
|
||||
self.assertIn("analytical", memory.recent_messages[0].traits)
|
||||
|
||||
def test_memory_trims_to_k_messages(self):
|
||||
memory = PersonaMemory(k=2)
|
||||
|
||||
memory.save_context({"input": "Hi"}, {"output": "Sorry about that!"})
|
||||
memory.save_context({"input": "Hello"}, {"output": "Apologies again!"})
|
||||
memory.save_context({"input": "Hey"}, {"output": "Great job!"}) # Third message
|
||||
|
||||
self.assertEqual(len(memory.recent_messages), 2)
|
||||
self.assertEqual(memory.recent_messages[0].content, "Apologies again!")
|
||||
self.assertEqual(memory.recent_messages[1].content, "Great job!")
|
||||
|
||||
def test_load_memory_variables_with_messages(self):
|
||||
memory = PersonaMemory()
|
||||
output_text = "Sorry about the mistake!"
|
||||
|
||||
memory.save_context({"input": "Hello"}, {"output": output_text})
|
||||
|
||||
result = memory.load_memory_variables({}, include_messages=True)
|
||||
self.assertIn("traits", result["persona"])
|
||||
self.assertIn("recent_messages", result["persona"])
|
||||
self.assertEqual(
|
||||
result["persona"]["recent_messages"][0]["content"], output_text
|
||||
)
|
||||
|
||||
def test_save_context_with_missing_output(self):
|
||||
memory = PersonaMemory()
|
||||
|
||||
memory.save_context({"input": "Hi"}, {}) # No output provided
|
||||
|
||||
self.assertEqual(len(memory.recent_messages), 1)
|
||||
self.assertEqual(
|
||||
memory.recent_messages[0].content, ""
|
||||
) # Should be empty string
|
||||
self.assertEqual(memory.recent_messages[0].traits, []) # No traits
|
||||
|
||||
def test_double_save_context_creates_two_entries(self):
|
||||
memory = PersonaMemory()
|
||||
|
||||
output_text = "Sorry for that mistake."
|
||||
memory.save_context({"input": "Hi"}, {"output": output_text})
|
||||
memory.save_context(
|
||||
{"input": "Hey"}, {"output": output_text}
|
||||
) # Same output again
|
||||
|
||||
self.assertEqual(len(memory.recent_messages), 2)
|
||||
self.assertEqual(memory.recent_messages[0].content, output_text)
|
||||
self.assertEqual(memory.recent_messages[1].content, output_text)
|
||||
|
||||
def test_full_conversation_simulation_with_failures(self):
|
||||
"""Simulate a longer conversation with mixed success and fallback during trait detection."""
|
||||
memory = PersonaMemory()
|
||||
|
||||
conversation = [
|
||||
("Hi there!", "I'm so happy to meet you!"), # Enthusiastic
|
||||
(
|
||||
"Could you help me?",
|
||||
"Of course! Happy to assist.",
|
||||
), # Friendly + Enthusiastic
|
||||
(
|
||||
"I'm unsure about this plan...",
|
||||
"Maybe we should rethink it.",
|
||||
), # Hesitant + Cautious (simulated failure here)
|
||||
(
|
||||
"Sorry about the delay.",
|
||||
"Apologies! It won't happen again.",
|
||||
), # Apologetic
|
||||
]
|
||||
|
||||
print("\n\n--- Starting full conversation simulation ---\n")
|
||||
|
||||
# Simulate that at message 3, external engine fails (mock failure)
|
||||
def dynamic_trait_detector(text: str):
|
||||
if "rethink" in text:
|
||||
print("Simulated API Failure during trait detection.\n")
|
||||
raise ValueError("Simulated API failure during cautious response")
|
||||
print(f"Simulated API Success - analyzing text: '{text}'\n")
|
||||
if "happy to meet you" in text:
|
||||
return {"enthusiastic": 1, "friendly": 1, "mocked_positive": 1}
|
||||
elif "Happy to assist" in text:
|
||||
return {"friendly": 1, "enthusiastic": 1, "mocked_positive": 1}
|
||||
elif "Apologies" in text:
|
||||
return {"apologetic": 1, "mocked_positive": 1}
|
||||
return {"mocked_positive": 1}
|
||||
|
||||
memory.trait_detection_engine = dynamic_trait_detector
|
||||
|
||||
for idx, (user_input, agent_output) in enumerate(conversation):
|
||||
print(f"--- Message {idx+1} ---")
|
||||
print(f"Input: {user_input}")
|
||||
print(f"Output: {agent_output}\n")
|
||||
memory.save_context({"input": user_input}, {"output": agent_output})
|
||||
|
||||
last_message = memory.recent_messages[-1]
|
||||
print(f"Saved Message Traits: {last_message.traits}\n")
|
||||
|
||||
print("--- Final accumulated traits ---")
|
||||
print(memory.persona_traits)
|
||||
print("\n--- End of conversation simulation ---\n")
|
||||
|
||||
# Assertions:
|
||||
# There should be 4 recent messages (conversation length)
|
||||
self.assertEqual(len(memory.recent_messages), 4)
|
||||
|
||||
# Traits should have accumulated normally, fallback triggered once
|
||||
detected_traits = set(memory.persona_traits)
|
||||
|
||||
expected_traits = {
|
||||
"enthusiastic",
|
||||
"friendly",
|
||||
"hesitant",
|
||||
"cautious",
|
||||
"apologetic",
|
||||
"mocked_positive",
|
||||
}
|
||||
|
||||
print(f"Detected traits: {detected_traits}")
|
||||
print(f"Expected traits: {expected_traits}")
|
||||
|
||||
for trait in expected_traits:
|
||||
self.assertIn(trait, detected_traits)
|
||||
|
||||
# Messages should have traits assigned
|
||||
for message in memory.recent_messages:
|
||||
self.assertIsInstance(message.traits, list)
|
||||
self.assertGreaterEqual(len(message.traits), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
"""
|
||||
Just the convo simulation:
|
||||
python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona.TestPersonaMemory.test_full_conversation_simulation_with_failure
|
||||
|
||||
All tests:
|
||||
python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona
|
||||
|
||||
or quitetly:
|
||||
python3 -m unittest -q libs.langchain.tests.unit_tests.memory.test_persona
|
||||
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user