mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
add save/load methods for persona, more unit tests
Changes - Add save_context to PersonaMemory - Add load_memory_variables to PersonaMemory - Expand unit tests for memory save/load - Add full conversation simulation test - Maintain full backward compatibility with other modules (hopefully)
This commit is contained in:
parent
05a80bdf49
commit
a982d2f27d
@ -4,15 +4,25 @@ import re
|
|||||||
|
|
||||||
from langchain_core.memory import BaseMemory
|
from langchain_core.memory import BaseMemory
|
||||||
|
|
||||||
|
|
||||||
class EnrichedMessage(BaseModel):
|
class EnrichedMessage(BaseModel):
|
||||||
"""A message enriched with persona traits and metadata."""
|
"""A message enriched with persona traits and metadata."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
content: str
|
content: str
|
||||||
traits: List[str] = Field(default_factory=list)
|
traits: List[str] = Field(default_factory=list)
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class PersonaMemory(BaseMemory):
|
class PersonaMemory(BaseMemory):
|
||||||
"""Memory that tracks evolving agent persona traits over a conversation."""
|
"""
|
||||||
|
Memory module that dynamically tracks emotional and behavioral traits
|
||||||
|
exhibited by an agent over the course of a conversation.
|
||||||
|
|
||||||
|
Traits are automatically detected from output messages and
|
||||||
|
accumulated across interactions. Recent messages and detected traits
|
||||||
|
can be retrieved to enrich future prompts or modify agent behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
memory_key: str = "persona"
|
memory_key: str = "persona"
|
||||||
input_key: str = "input"
|
input_key: str = "input"
|
||||||
@ -32,8 +42,9 @@ class PersonaMemory(BaseMemory):
|
|||||||
self.recent_messages = []
|
self.recent_messages = []
|
||||||
|
|
||||||
def _detect_traits(self, text: str) -> Dict[str, int]:
|
def _detect_traits(self, text: str) -> Dict[str, int]:
|
||||||
"""Detect persona traits using both a default method and optionally an external engine.
|
"""
|
||||||
|
Detect persona traits using both a default method and optionally an external engine.
|
||||||
|
|
||||||
Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded)
|
Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded)
|
||||||
"""
|
"""
|
||||||
# Always run the fast, local simple detection first
|
# Always run the fast, local simple detection first
|
||||||
@ -42,16 +53,17 @@ class PersonaMemory(BaseMemory):
|
|||||||
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
|
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
|
||||||
"formal": ["Dear", "Sincerely", "Respectfully"],
|
"formal": ["Dear", "Sincerely", "Respectfully"],
|
||||||
"cautious": ["maybe", "perhaps", "I think", "it could be"],
|
"cautious": ["maybe", "perhaps", "I think", "it could be"],
|
||||||
"friendly": ["Hi", "Hey", "Hello", "Good to see you"],
|
"hesitant": ["maybe", "...", "I'm not sure", "perhaps", "unsure"],
|
||||||
|
"friendly": ["Hi", "Hey", "Hello", "Good to see you", "friendly"],
|
||||||
"curious": ["?", "I wonder", "Could you"],
|
"curious": ["?", "I wonder", "Could you"],
|
||||||
"hesitant": ["...", "I'm not sure", "perhaps"],
|
"analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"],
|
||||||
}
|
}
|
||||||
|
|
||||||
trait_hits: Dict[str, int] = {}
|
trait_hits: Dict[str, int] = {}
|
||||||
lowered_text = text.lower()
|
lowered_text = text.lower()
|
||||||
|
|
||||||
# Clean the text for word matching by removing punctuation
|
# Clean the text for word matching by removing punctuation
|
||||||
clean_text = re.sub(r'[^\w\s]', ' ', lowered_text)
|
clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
|
||||||
words = clean_text.split()
|
words = clean_text.split()
|
||||||
|
|
||||||
for trait, patterns in trait_patterns.items():
|
for trait, patterns in trait_patterns.items():
|
||||||
@ -59,8 +71,7 @@ class PersonaMemory(BaseMemory):
|
|||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
pattern_lower = pattern.lower()
|
pattern_lower = pattern.lower()
|
||||||
if trait == "friendly":
|
if trait == "friendly":
|
||||||
if (lowered_text.startswith(pattern_lower) or
|
if lowered_text.startswith(pattern_lower) or pattern_lower in words:
|
||||||
pattern_lower in words):
|
|
||||||
count += 1
|
count += 1
|
||||||
elif trait == "enthusiastic":
|
elif trait == "enthusiastic":
|
||||||
if pattern == "!":
|
if pattern == "!":
|
||||||
@ -68,7 +79,7 @@ class PersonaMemory(BaseMemory):
|
|||||||
else:
|
else:
|
||||||
if pattern_lower in lowered_text:
|
if pattern_lower in lowered_text:
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
trait_hits[trait] = count
|
trait_hits[trait] = count
|
||||||
|
|
||||||
@ -83,3 +94,43 @@ class PersonaMemory(BaseMemory):
|
|||||||
|
|
||||||
# Fall back to simple default detection if external fails or is unavailable
|
# Fall back to simple default detection if external fails or is unavailable
|
||||||
return trait_hits
|
return trait_hits
|
||||||
|
|
||||||
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], include_messages: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Return the stored persona traits and optionally recent conversation messages."""
|
||||||
|
|
||||||
|
memory_data = {"traits": self.persona_traits.copy()}
|
||||||
|
|
||||||
|
if include_messages:
|
||||||
|
memory_data["recent_messages"] = [
|
||||||
|
message.model_dump() for message in self.recent_messages
|
||||||
|
]
|
||||||
|
|
||||||
|
return {self.memory_key: memory_data}
|
||||||
|
|
||||||
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
|
"""Analyze outputs and update the persona traits and recent messages."""
|
||||||
|
|
||||||
|
input_text = inputs.get(self.input_key, "")
|
||||||
|
output_text = outputs.get(self.output_key, "")
|
||||||
|
|
||||||
|
traits_detected = self._detect_traits(output_text)
|
||||||
|
|
||||||
|
message = EnrichedMessage(
|
||||||
|
id=str(len(self.recent_messages) + 1),
|
||||||
|
content=output_text,
|
||||||
|
traits=list(traits_detected.keys()),
|
||||||
|
metadata={"traits_count": traits_detected},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.recent_messages.append(message)
|
||||||
|
|
||||||
|
# Trim messages to maintain memory size k
|
||||||
|
if len(self.recent_messages) > self.k:
|
||||||
|
self.recent_messages = self.recent_messages[-self.k :]
|
||||||
|
|
||||||
|
# Update persona traits
|
||||||
|
self.persona_traits = list(
|
||||||
|
{trait for msg in self.recent_messages for trait in msg.traits}
|
||||||
|
)
|
||||||
|
@ -7,31 +7,27 @@ from langchain.schema import AIMessage, HumanMessage, BaseMessage
|
|||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def capture_output():
|
|
||||||
"""Capture stdout and stderr."""
|
|
||||||
new_out, new_err = StringIO(), StringIO()
|
|
||||||
old_out, old_err = sys.stdout, sys.stderr
|
|
||||||
try:
|
|
||||||
sys.stdout, sys.stderr = new_out, new_err
|
|
||||||
yield sys.stdout, sys.stderr
|
|
||||||
finally:
|
|
||||||
sys.stdout, sys.stderr = old_out, old_err
|
|
||||||
|
|
||||||
class MockChatModel:
|
class MockChatModel:
|
||||||
"""A simple mock chat model that can simulate both success and failure cases."""
|
"""A simple mock chat model that can simulate both success and failure cases."""
|
||||||
|
|
||||||
def __init__(self, response: Optional[str] = None, should_fail: bool = False, error_msg: str = ""):
|
def __init__(
|
||||||
|
self,
|
||||||
|
response: Optional[str] = None,
|
||||||
|
should_fail: bool = False,
|
||||||
|
error_msg: str = "",
|
||||||
|
):
|
||||||
self.response = response
|
self.response = response
|
||||||
self.should_fail = should_fail
|
self.should_fail = should_fail
|
||||||
self.error_msg = error_msg
|
self.error_msg = error_msg
|
||||||
|
|
||||||
def invoke(self, messages: List[BaseMessage], **kwargs) -> AIMessage:
|
def invoke(self, messages: List[BaseMessage], **kwargs) -> AIMessage:
|
||||||
"""Simulate chat model invocation."""
|
"""Simulate chat model invocation."""
|
||||||
if self.should_fail:
|
if self.should_fail:
|
||||||
raise ValueError(self.error_msg)
|
raise ValueError(self.error_msg)
|
||||||
return AIMessage(content=self.response)
|
return AIMessage(content=self.response)
|
||||||
|
|
||||||
|
|
||||||
class TestPersonaMemory(unittest.TestCase):
|
class TestPersonaMemory(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Set up test-level attributes."""
|
"""Set up test-level attributes."""
|
||||||
@ -40,13 +36,9 @@ class TestPersonaMemory(unittest.TestCase):
|
|||||||
|
|
||||||
def _create_successful_chat_model(self):
|
def _create_successful_chat_model(self):
|
||||||
"""Create a mock chat model that returns successful responses."""
|
"""Create a mock chat model that returns successful responses."""
|
||||||
response = json.dumps({
|
response = json.dumps(
|
||||||
"traits": {
|
{"traits": {"confident": 2, "analytical": 1, "professional": 1}}
|
||||||
"confident": 2,
|
)
|
||||||
"analytical": 1,
|
|
||||||
"professional": 1
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return MockChatModel(response=response)
|
return MockChatModel(response=response)
|
||||||
|
|
||||||
def _create_failing_chat_model(self, error_type: str = "rate_limit"):
|
def _create_failing_chat_model(self, error_type: str = "rate_limit"):
|
||||||
@ -57,14 +49,11 @@ class TestPersonaMemory(unittest.TestCase):
|
|||||||
error_msg = "Invalid API key"
|
error_msg = "Invalid API key"
|
||||||
else:
|
else:
|
||||||
error_msg = "Internal server error"
|
error_msg = "Internal server error"
|
||||||
|
|
||||||
return MockChatModel(should_fail=True, error_msg=error_msg)
|
return MockChatModel(should_fail=True, error_msg=error_msg)
|
||||||
|
|
||||||
def test_enriched_message_creation(self):
|
def test_enriched_message_creation(self):
|
||||||
input_data = {
|
input_data = {"id": "msg-001", "content": "Hello!"}
|
||||||
"id": "msg-001",
|
|
||||||
"content": "Hello!"
|
|
||||||
}
|
|
||||||
message = EnrichedMessage(**input_data)
|
message = EnrichedMessage(**input_data)
|
||||||
self.assertEqual(message.id, "msg-001")
|
self.assertEqual(message.id, "msg-001")
|
||||||
self.assertEqual(message.content, "Hello!")
|
self.assertEqual(message.content, "Hello!")
|
||||||
@ -129,7 +118,7 @@ class TestPersonaMemory(unittest.TestCase):
|
|||||||
memory = PersonaMemory()
|
memory = PersonaMemory()
|
||||||
output_texts = [
|
output_texts = [
|
||||||
"I'm sorry, truly sorry about the confusion.",
|
"I'm sorry, truly sorry about the confusion.",
|
||||||
"Hello there! Great job on that task!"
|
"Hello there! Great job on that task!",
|
||||||
]
|
]
|
||||||
|
|
||||||
for text in output_texts:
|
for text in output_texts:
|
||||||
@ -153,12 +142,12 @@ class TestPersonaMemory(unittest.TestCase):
|
|||||||
|
|
||||||
def test_successful_api_detection(self):
|
def test_successful_api_detection(self):
|
||||||
chat_model = self._create_successful_chat_model()
|
chat_model = self._create_successful_chat_model()
|
||||||
|
|
||||||
def trait_detector(text: str):
|
def trait_detector(text: str):
|
||||||
messages = [HumanMessage(content=text)]
|
messages = [HumanMessage(content=text)]
|
||||||
response = chat_model.invoke(messages)
|
response = chat_model.invoke(messages)
|
||||||
return json.loads(response.content)["traits"]
|
return json.loads(response.content)["traits"]
|
||||||
|
|
||||||
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
||||||
text = "Based on the analysis, I can confidently say this is the best approach."
|
text = "Based on the analysis, I can confidently say this is the best approach."
|
||||||
traits = memory._detect_traits(text)
|
traits = memory._detect_traits(text)
|
||||||
@ -171,12 +160,12 @@ class TestPersonaMemory(unittest.TestCase):
|
|||||||
|
|
||||||
def test_rate_limit_failure(self):
|
def test_rate_limit_failure(self):
|
||||||
chat_model = self._create_failing_chat_model("rate_limit")
|
chat_model = self._create_failing_chat_model("rate_limit")
|
||||||
|
|
||||||
def trait_detector(text: str):
|
def trait_detector(text: str):
|
||||||
messages = [HumanMessage(content=text)]
|
messages = [HumanMessage(content=text)]
|
||||||
response = chat_model.invoke(messages)
|
response = chat_model.invoke(messages)
|
||||||
return json.loads(response.content)["traits"]
|
return json.loads(response.content)["traits"]
|
||||||
|
|
||||||
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
||||||
text = "I'm so sorry about the confusion."
|
text = "I'm so sorry about the confusion."
|
||||||
traits = memory._detect_traits(text)
|
traits = memory._detect_traits(text)
|
||||||
@ -184,12 +173,12 @@ class TestPersonaMemory(unittest.TestCase):
|
|||||||
|
|
||||||
def test_authentication_failure(self):
|
def test_authentication_failure(self):
|
||||||
chat_model = self._create_failing_chat_model("authentication")
|
chat_model = self._create_failing_chat_model("authentication")
|
||||||
|
|
||||||
def trait_detector(text: str):
|
def trait_detector(text: str):
|
||||||
messages = [HumanMessage(content=text)]
|
messages = [HumanMessage(content=text)]
|
||||||
response = chat_model.invoke(messages)
|
response = chat_model.invoke(messages)
|
||||||
return json.loads(response.content)["traits"]
|
return json.loads(response.content)["traits"]
|
||||||
|
|
||||||
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
||||||
text = "I apologize for the mistake."
|
text = "I apologize for the mistake."
|
||||||
traits = memory._detect_traits(text)
|
traits = memory._detect_traits(text)
|
||||||
@ -197,16 +186,189 @@ class TestPersonaMemory(unittest.TestCase):
|
|||||||
|
|
||||||
def test_server_error_failure(self):
|
def test_server_error_failure(self):
|
||||||
chat_model = self._create_failing_chat_model("server_error")
|
chat_model = self._create_failing_chat_model("server_error")
|
||||||
|
|
||||||
def trait_detector(text: str):
|
def trait_detector(text: str):
|
||||||
messages = [HumanMessage(content=text)]
|
messages = [HumanMessage(content=text)]
|
||||||
response = chat_model.invoke(messages)
|
response = chat_model.invoke(messages)
|
||||||
return json.loads(response.content)["traits"]
|
return json.loads(response.content)["traits"]
|
||||||
|
|
||||||
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
memory = PersonaMemory(trait_detection_engine=trait_detector)
|
||||||
text = "I'm very sorry about that."
|
text = "I'm very sorry about that."
|
||||||
traits = memory._detect_traits(text)
|
traits = memory._detect_traits(text)
|
||||||
self.assertIn("apologetic", traits) # Should fall back to default detection
|
self.assertIn("apologetic", traits) # Should fall back to default detection
|
||||||
|
|
||||||
|
def test_save_context_updates_memory(self):
|
||||||
|
memory = PersonaMemory()
|
||||||
|
input_text = "Hello!"
|
||||||
|
output_text = "I'm sorry for the confusion."
|
||||||
|
|
||||||
|
memory.save_context({"input": input_text}, {"output": output_text})
|
||||||
|
self.assertIn("apologetic", memory.persona_traits)
|
||||||
|
self.assertEqual(len(memory.recent_messages), 1)
|
||||||
|
self.assertEqual(memory.recent_messages[0].content, output_text)
|
||||||
|
self.assertIn("apologetic", memory.recent_messages[0].traits)
|
||||||
|
|
||||||
|
def test_load_memory_variables_returns_traits(self):
|
||||||
|
memory = PersonaMemory()
|
||||||
|
memory.persona_traits = ["friendly", "apologetic"]
|
||||||
|
|
||||||
|
result = memory.load_memory_variables({})
|
||||||
|
self.assertEqual(result["persona"], {"traits": ["friendly", "apologetic"]})
|
||||||
|
|
||||||
|
def test_save_context_updates_memory_with_multiple_traits(self):
|
||||||
|
memory = PersonaMemory()
|
||||||
|
input_text = "Hello!"
|
||||||
|
output_text = "I'm sorry for the confusion. I'm also friendly and analytical."
|
||||||
|
|
||||||
|
memory.save_context({"input": input_text}, {"output": output_text})
|
||||||
|
self.assertIn("apologetic", memory.persona_traits)
|
||||||
|
self.assertIn("friendly", memory.persona_traits)
|
||||||
|
self.assertIn("analytical", memory.persona_traits)
|
||||||
|
self.assertEqual(len(memory.recent_messages), 1)
|
||||||
|
self.assertEqual(memory.recent_messages[0].content, output_text)
|
||||||
|
self.assertIn("apologetic", memory.recent_messages[0].traits)
|
||||||
|
self.assertIn("friendly", memory.recent_messages[0].traits)
|
||||||
|
self.assertIn("analytical", memory.recent_messages[0].traits)
|
||||||
|
|
||||||
|
def test_memory_trims_to_k_messages(self):
|
||||||
|
memory = PersonaMemory(k=2)
|
||||||
|
|
||||||
|
memory.save_context({"input": "Hi"}, {"output": "Sorry about that!"})
|
||||||
|
memory.save_context({"input": "Hello"}, {"output": "Apologies again!"})
|
||||||
|
memory.save_context({"input": "Hey"}, {"output": "Great job!"}) # Third message
|
||||||
|
|
||||||
|
self.assertEqual(len(memory.recent_messages), 2)
|
||||||
|
self.assertEqual(memory.recent_messages[0].content, "Apologies again!")
|
||||||
|
self.assertEqual(memory.recent_messages[1].content, "Great job!")
|
||||||
|
|
||||||
|
def test_load_memory_variables_with_messages(self):
|
||||||
|
memory = PersonaMemory()
|
||||||
|
output_text = "Sorry about the mistake!"
|
||||||
|
|
||||||
|
memory.save_context({"input": "Hello"}, {"output": output_text})
|
||||||
|
|
||||||
|
result = memory.load_memory_variables({}, include_messages=True)
|
||||||
|
self.assertIn("traits", result["persona"])
|
||||||
|
self.assertIn("recent_messages", result["persona"])
|
||||||
|
self.assertEqual(
|
||||||
|
result["persona"]["recent_messages"][0]["content"], output_text
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_save_context_with_missing_output(self):
|
||||||
|
memory = PersonaMemory()
|
||||||
|
|
||||||
|
memory.save_context({"input": "Hi"}, {}) # No output provided
|
||||||
|
|
||||||
|
self.assertEqual(len(memory.recent_messages), 1)
|
||||||
|
self.assertEqual(
|
||||||
|
memory.recent_messages[0].content, ""
|
||||||
|
) # Should be empty string
|
||||||
|
self.assertEqual(memory.recent_messages[0].traits, []) # No traits
|
||||||
|
|
||||||
|
def test_double_save_context_creates_two_entries(self):
|
||||||
|
memory = PersonaMemory()
|
||||||
|
|
||||||
|
output_text = "Sorry for that mistake."
|
||||||
|
memory.save_context({"input": "Hi"}, {"output": output_text})
|
||||||
|
memory.save_context(
|
||||||
|
{"input": "Hey"}, {"output": output_text}
|
||||||
|
) # Same output again
|
||||||
|
|
||||||
|
self.assertEqual(len(memory.recent_messages), 2)
|
||||||
|
self.assertEqual(memory.recent_messages[0].content, output_text)
|
||||||
|
self.assertEqual(memory.recent_messages[1].content, output_text)
|
||||||
|
|
||||||
|
def test_full_conversation_simulation_with_failures(self):
|
||||||
|
"""Simulate a longer conversation with mixed success and fallback during trait detection."""
|
||||||
|
memory = PersonaMemory()
|
||||||
|
|
||||||
|
conversation = [
|
||||||
|
("Hi there!", "I'm so happy to meet you!"), # Enthusiastic
|
||||||
|
(
|
||||||
|
"Could you help me?",
|
||||||
|
"Of course! Happy to assist.",
|
||||||
|
), # Friendly + Enthusiastic
|
||||||
|
(
|
||||||
|
"I'm unsure about this plan...",
|
||||||
|
"Maybe we should rethink it.",
|
||||||
|
), # Hesitant + Cautious (simulated failure here)
|
||||||
|
(
|
||||||
|
"Sorry about the delay.",
|
||||||
|
"Apologies! It won't happen again.",
|
||||||
|
), # Apologetic
|
||||||
|
]
|
||||||
|
|
||||||
|
print("\n\n--- Starting full conversation simulation ---\n")
|
||||||
|
|
||||||
|
# Simulate that at message 3, external engine fails (mock failure)
|
||||||
|
def dynamic_trait_detector(text: str):
|
||||||
|
if "rethink" in text:
|
||||||
|
print("Simulated API Failure during trait detection.\n")
|
||||||
|
raise ValueError("Simulated API failure during cautious response")
|
||||||
|
print(f"Simulated API Success - analyzing text: '{text}'\n")
|
||||||
|
if "happy to meet you" in text:
|
||||||
|
return {"enthusiastic": 1, "friendly": 1, "mocked_positive": 1}
|
||||||
|
elif "Happy to assist" in text:
|
||||||
|
return {"friendly": 1, "enthusiastic": 1, "mocked_positive": 1}
|
||||||
|
elif "Apologies" in text:
|
||||||
|
return {"apologetic": 1, "mocked_positive": 1}
|
||||||
|
return {"mocked_positive": 1}
|
||||||
|
|
||||||
|
memory.trait_detection_engine = dynamic_trait_detector
|
||||||
|
|
||||||
|
for idx, (user_input, agent_output) in enumerate(conversation):
|
||||||
|
print(f"--- Message {idx+1} ---")
|
||||||
|
print(f"Input: {user_input}")
|
||||||
|
print(f"Output: {agent_output}\n")
|
||||||
|
memory.save_context({"input": user_input}, {"output": agent_output})
|
||||||
|
|
||||||
|
last_message = memory.recent_messages[-1]
|
||||||
|
print(f"Saved Message Traits: {last_message.traits}\n")
|
||||||
|
|
||||||
|
print("--- Final accumulated traits ---")
|
||||||
|
print(memory.persona_traits)
|
||||||
|
print("\n--- End of conversation simulation ---\n")
|
||||||
|
|
||||||
|
# Assertions:
|
||||||
|
# There should be 4 recent messages (conversation length)
|
||||||
|
self.assertEqual(len(memory.recent_messages), 4)
|
||||||
|
|
||||||
|
# Traits should have accumulated normally, fallback triggered once
|
||||||
|
detected_traits = set(memory.persona_traits)
|
||||||
|
|
||||||
|
expected_traits = {
|
||||||
|
"enthusiastic",
|
||||||
|
"friendly",
|
||||||
|
"hesitant",
|
||||||
|
"cautious",
|
||||||
|
"apologetic",
|
||||||
|
"mocked_positive",
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Detected traits: {detected_traits}")
|
||||||
|
print(f"Expected traits: {expected_traits}")
|
||||||
|
|
||||||
|
for trait in expected_traits:
|
||||||
|
self.assertIn(trait, detected_traits)
|
||||||
|
|
||||||
|
# Messages should have traits assigned
|
||||||
|
for message in memory.recent_messages:
|
||||||
|
self.assertIsInstance(message.traits, list)
|
||||||
|
self.assertGreaterEqual(len(message.traits), 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Just the convo simulation:
|
||||||
|
python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona.TestPersonaMemory.test_full_conversation_simulation_with_failure
|
||||||
|
|
||||||
|
All tests:
|
||||||
|
python3 -m unittest -v libs.langchain.tests.unit_tests.memory.test_persona
|
||||||
|
|
||||||
|
or quitetly:
|
||||||
|
python3 -m unittest -q libs.langchain.tests.unit_tests.memory.test_persona
|
||||||
|
|
||||||
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user