Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-08 04:25:46 +00:00)
initial PersonaMemory class with trait detection
This PR introduces the foundation of a PersonaMemory class for LangChain. It tracks evolving persona traits detected from model outputs, using either simple rule-based matching or an external LLM-backed trait-detection engine. The core memory methods (save_context, load_memory_variables) are not implemented yet and will be completed in a follow-up PR. The internal trait detection and the fallback behaviour are already fully tested. Future work includes integrating persona tracking with broader memory chains and serialisation formats.
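A minimal usage sketch, mirroring the unit tests in this PR, shows the intended contract. Only the internal _detect_traits helper exists so far, so the sketch calls it directly; the public memory API is not wired up yet.

    from langchain.memory.persona import PersonaMemory

    # Rule-based detection only: no external engine configured.
    memory = PersonaMemory()
    print(memory._detect_traits("Sorry about that! Hello!"))
    # -> {'apologetic': 1, 'enthusiastic': 2, 'friendly': 1}

    # Any callable mapping text -> Dict[str, int] can serve as the external engine.
    # If it raises, the rule-based result above is returned instead.
    memory = PersonaMemory(trait_detection_engine=lambda text: {"confident": 1})
    print(memory._detect_traits("Sorry about that! Hello!"))
    # -> {'confident': 1}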
parent 04a899ebe3
commit 05a80bdf49
85  libs/langchain/langchain/memory/persona.py  Normal file
@@ -0,0 +1,85 @@
import re
from typing import Any, Callable, Dict, List, Optional

from pydantic import BaseModel, Field

from langchain_core.memory import BaseMemory


class EnrichedMessage(BaseModel):
    """A message enriched with persona traits and metadata."""

    id: str
    content: str
    traits: List[str] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)


class PersonaMemory(BaseMemory):
    """Memory that tracks evolving agent persona traits over a conversation."""

    memory_key: str = "persona"
    input_key: str = "input"
    output_key: str = "output"
    k: int = 10
    trait_detection_engine: Optional[Callable[[str], Dict[str, int]]] = None
    persona_traits: List[str] = Field(default_factory=list)
    recent_messages: List[Any] = Field(default_factory=list)

    @property
    def memory_variables(self) -> List[str]:
        return [self.memory_key]

    def clear(self) -> None:
        """Clear the memory state."""
        self.persona_traits = []
        self.recent_messages = []

    def _detect_traits(self, text: str) -> Dict[str, int]:
        """Detect persona traits using a rule-based default and, optionally, an external engine.

        Always returns a usable result, even if the external engine fails;
        the hard-coded fallback is simply lower quality.
        """
        # Always run the fast, local rule-based detection first.
        trait_patterns = {
            "apologetic": ["sorry", "apologize", "apologies", "I apologize"],
            "enthusiastic": ["!", "awesome", "great job", "fantastic"],
            "formal": ["Dear", "Sincerely", "Respectfully"],
            "cautious": ["maybe", "perhaps", "I think", "it could be"],
            "friendly": ["Hi", "Hey", "Hello", "Good to see you"],
            "curious": ["?", "I wonder", "Could you"],
            "hesitant": ["...", "I'm not sure", "perhaps"],
        }

        trait_hits: Dict[str, int] = {}
        lowered_text = text.lower()

        # Clean the text for whole-word matching by removing punctuation.
        clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
        words = clean_text.split()

        for trait, patterns in trait_patterns.items():
            count = 0
            for pattern in patterns:
                pattern_lower = pattern.lower()
                if trait == "friendly":
                    # Greetings count only at the start of the text or as whole words.
                    if (lowered_text.startswith(pattern_lower) or
                            pattern_lower in words):
                        count += 1
                elif trait == "enthusiastic":
                    # Only the "!" pattern is counted here; every exclamation
                    # mark in the text contributes to the score.
                    if pattern == "!":
                        count = lowered_text.count("!")
                else:
                    # All other traits use simple substring matching.
                    if pattern_lower in lowered_text:
                        count += 1

            if count > 0:
                trait_hits[trait] = count

        # Now attempt the external engine, if one is configured.
        if self.trait_detection_engine:
            try:
                external_traits = self.trait_detection_engine(text)
                if isinstance(external_traits, dict):
                    return external_traits
            except Exception:
                pass  # Silently fall back to the default detection.

        # Fall back to the simple default detection if the external engine
        # fails or is unavailable.
        return trait_hits
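Note: the save_context and load_memory_variables methods referred to in the PR description are not visible in this hunk, and langchain_core's BaseMemory declares them as abstract. A minimal sketch of how a follow-up could fill them in, written as an illustrative subclass so the snippet stands alone; the class name, the trait-accumulation logic, and the k-based trimming are assumptions, not part of this commit:

    from typing import Any, Dict

    from langchain.memory.persona import PersonaMemory


    class PersonaMemoryDraft(PersonaMemory):
        """Illustrative only: one way the stubbed methods might be completed."""

        def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
            # Expose the accumulated traits under the configured memory key.
            return {self.memory_key: list(self.persona_traits)}

        def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
            # Detect traits on the model output, accumulate any new ones,
            # and keep only the last k raw messages.
            text = outputs.get(self.output_key, "")
            for trait in self._detect_traits(text):
                if trait not in self.persona_traits:
                    self.persona_traits.append(trait)
            self.recent_messages.append(text)
            self.recent_messages = self.recent_messages[-self.k:]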
212  libs/langchain/tests/unit_tests/memory/test_persona.py  Normal file
@@ -0,0 +1,212 @@
import json
import sys
import unittest
from contextlib import contextmanager
from io import StringIO
from typing import List, Optional

from langchain.memory.persona import PersonaMemory, EnrichedMessage
from langchain.schema import AIMessage, HumanMessage, BaseMessage


@contextmanager
def capture_output():
    """Capture stdout and stderr."""
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield sys.stdout, sys.stderr
    finally:
        sys.stdout, sys.stderr = old_out, old_err


class MockChatModel:
    """A simple mock chat model that can simulate both success and failure cases."""

    def __init__(self, response: Optional[str] = None, should_fail: bool = False, error_msg: str = ""):
        self.response = response
        self.should_fail = should_fail
        self.error_msg = error_msg

    def invoke(self, messages: List[BaseMessage], **kwargs) -> AIMessage:
        """Simulate chat model invocation."""
        if self.should_fail:
            raise ValueError(self.error_msg)
        return AIMessage(content=self.response)


class TestPersonaMemory(unittest.TestCase):
    def setUp(self):
        """Set up test-level attributes."""
        self.maxDiff = None
        unittest.TestResult.showAll = True

    def _create_successful_chat_model(self):
        """Create a mock chat model that returns successful responses."""
        response = json.dumps({
            "traits": {
                "confident": 2,
                "analytical": 1,
                "professional": 1
            }
        })
        return MockChatModel(response=response)

    def _create_failing_chat_model(self, error_type: str = "rate_limit"):
        """Create a mock chat model that simulates failures."""
        if error_type == "rate_limit":
            error_msg = "Rate limit exceeded"
        elif error_type == "authentication":
            error_msg = "Invalid API key"
        else:
            error_msg = "Internal server error"

        return MockChatModel(should_fail=True, error_msg=error_msg)

    def test_enriched_message_creation(self):
        input_data = {
            "id": "msg-001",
            "content": "Hello!"
        }
        message = EnrichedMessage(**input_data)
        self.assertEqual(message.id, "msg-001")
        self.assertEqual(message.content, "Hello!")
        self.assertEqual(message.traits, [])
        self.assertEqual(message.metadata, {})

    def test_persona_memory_initialization(self):
        memory = PersonaMemory()
        self.assertEqual(memory.memory_key, "persona")
        self.assertEqual(memory.input_key, "input")
        self.assertEqual(memory.output_key, "output")
        self.assertEqual(memory.k, 10)
        self.assertEqual(memory.persona_traits, [])
        self.assertEqual(memory.recent_messages, [])

    def test_memory_variables_property(self):
        custom_key = "custom_memory_key"
        memory = PersonaMemory(memory_key=custom_key)
        result = memory.memory_variables
        self.assertEqual(result, [custom_key])

    def test_default_trait_detection_simple_text(self):
        memory = PersonaMemory()
        text = "I'm so sorry again! Hello!"
        traits = memory._detect_traits(text)
        self.assertIn("apologetic", traits)
        self.assertIn("friendly", traits)
        self.assertIn("enthusiastic", traits)
        self.assertEqual(traits["apologetic"], 1)
        self.assertEqual(traits["friendly"], 1)
        self.assertEqual(traits["enthusiastic"], 2)

    def test_default_trait_detection_no_traits(self):
        memory = PersonaMemory()
        text = "This sentence has no emotional clues."
        traits = memory._detect_traits(text)
        self.assertEqual(traits, {})

    def test_fallback_to_default_on_engine_failure(self):
        def broken_detector(text: str):
            raise ValueError("Simulated engine failure")

        memory = PersonaMemory(trait_detection_engine=broken_detector)
        text = "Hey, sorry about that."
        traits = memory._detect_traits(text)
        self.assertIn("apologetic", traits)
        self.assertIn("friendly", traits)
        self.assertEqual(traits["apologetic"], 1)
        self.assertEqual(traits["friendly"], 1)

    def test_external_trait_detection_mock(self):
        def mock_trait_detector(text: str):
            return {"mocked_trait": 2}

        memory = PersonaMemory(trait_detection_engine=mock_trait_detector)
        text = "This is a mock test."
        traits = memory._detect_traits(text)
        self.assertIn("mocked_trait", traits)
        self.assertEqual(traits["mocked_trait"], 2)

    def test_multiple_outputs_accumulate_traits(self):
        memory = PersonaMemory()
        output_texts = [
            "I'm sorry, truly sorry about the confusion.",
            "Hello there! Great job on that task!"
        ]

        for text in output_texts:
            traits = memory._detect_traits(text)
            memory.persona_traits.extend(list(traits.keys()))

        self.assertIn("apologetic", memory.persona_traits)
        self.assertIn("friendly", memory.persona_traits)
        self.assertIn("enthusiastic", memory.persona_traits)
        self.assertEqual(len(memory.persona_traits), 3)

    def test_clear_method_resets_memory(self):
        memory = PersonaMemory()
        initial_traits = ["friendly", "apologetic"]
        initial_messages = ["dummy_message"]
        memory.persona_traits = initial_traits
        memory.recent_messages = initial_messages
        memory.clear()
        self.assertEqual(memory.persona_traits, [])
        self.assertEqual(memory.recent_messages, [])

    def test_successful_api_detection(self):
        chat_model = self._create_successful_chat_model()

        def trait_detector(text: str):
            messages = [HumanMessage(content=text)]
            response = chat_model.invoke(messages)
            return json.loads(response.content)["traits"]

        memory = PersonaMemory(trait_detection_engine=trait_detector)
        text = "Based on the analysis, I can confidently say this is the best approach."
        traits = memory._detect_traits(text)
        self.assertIn("confident", traits)
        self.assertIn("analytical", traits)
        self.assertIn("professional", traits)
        self.assertEqual(traits["confident"], 2)
        self.assertEqual(traits["analytical"], 1)
        self.assertEqual(traits["professional"], 1)

    def test_rate_limit_failure(self):
        chat_model = self._create_failing_chat_model("rate_limit")

        def trait_detector(text: str):
            messages = [HumanMessage(content=text)]
            response = chat_model.invoke(messages)
            return json.loads(response.content)["traits"]

        memory = PersonaMemory(trait_detection_engine=trait_detector)
        text = "I'm so sorry about the confusion."
        traits = memory._detect_traits(text)
        self.assertIn("apologetic", traits)  # Should fall back to default detection

    def test_authentication_failure(self):
        chat_model = self._create_failing_chat_model("authentication")

        def trait_detector(text: str):
            messages = [HumanMessage(content=text)]
            response = chat_model.invoke(messages)
            return json.loads(response.content)["traits"]

        memory = PersonaMemory(trait_detection_engine=trait_detector)
        text = "I apologize for the mistake."
        traits = memory._detect_traits(text)
        self.assertIn("apologetic", traits)  # Should fall back to default detection

    def test_server_error_failure(self):
        chat_model = self._create_failing_chat_model("server_error")

        def trait_detector(text: str):
            messages = [HumanMessage(content=text)]
            response = chat_model.invoke(messages)
            return json.loads(response.content)["traits"]

        memory = PersonaMemory(trait_detection_engine=trait_detector)
        text = "I'm very sorry about that."
        traits = memory._detect_traits(text)
        self.assertIn("apologetic", traits)  # Should fall back to default detection


if __name__ == "__main__":
    unittest.main()
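For completeness, the trait_detector helpers in the failure tests above hint at how a real LLM would be plugged in: the Callable[[str], Dict[str, int]] contract keeps PersonaMemory decoupled from any particular chat client. A rough sketch, assuming langchain_openai is installed and OPENAI_API_KEY is set; the model name, prompt wording, and JSON response contract are illustrative assumptions, not part of this PR:

    import json
    from typing import Dict

    from langchain_openai import ChatOpenAI  # assumed available; any chat model would do
    from langchain.memory.persona import PersonaMemory
    from langchain.schema import HumanMessage

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

    def llm_trait_detector(text: str) -> Dict[str, int]:
        # Ask the model for a JSON object like {"traits": {"confident": 2}}.
        prompt = (
            'Return only a JSON object of the form {"traits": {"<trait>": <count>}} '
            "describing the persona traits expressed in this text:\n" + text
        )
        response = llm.invoke([HumanMessage(content=prompt)])
        return json.loads(response.content)["traits"]

    memory = PersonaMemory(trait_detection_engine=llm_trait_detector)
    traits = memory._detect_traits("I'm confident this is the right approach!")
    # If the call fails (rate limit, bad key, server error), _detect_traits
    # silently falls back to the rule-based patterns, exactly as the failure
    # tests above exercise.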