This commit is contained in:
Mikolaj Mikuliszyn 2025-04-27 17:41:37 +01:00
parent a982d2f27d
commit 21c491a022

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
from pydantic import BaseModel, Field
import re
@ -10,8 +10,8 @@ class EnrichedMessage(BaseModel):
id: str
content: str
traits: List[str] = Field(default_factory=list)
metadata: Dict[str, Any] = Field(default_factory=dict)
traits: list[str] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
class PersonaMemory(BaseMemory):
@ -28,12 +28,12 @@ class PersonaMemory(BaseMemory):
input_key: str = "input"
output_key: str = "output"
k: int = 10
trait_detection_engine: Optional[Callable[[str], Dict[str, int]]] = None
persona_traits: List[str] = Field(default_factory=list)
recent_messages: List[Any] = Field(default_factory=list)
trait_detection_engine: Optional[Callable[[str], dict[str, int]]] = None
persona_traits: list[str] = Field(default_factory=list)
recent_messages: list[Any] = Field(default_factory=list)
@property
def memory_variables(self) -> List[str]:
def memory_variables(self) -> list[str]:
return [self.memory_key]
def clear(self) -> None:
@ -41,13 +41,12 @@ class PersonaMemory(BaseMemory):
self.persona_traits = []
self.recent_messages = []
def _detect_traits(self, text: str) -> Dict[str, int]:
def _detect_traits(self, text: str) -> dict[str, int]:
"""
Detect persona traits using both a default method and optionally an external engine.
Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded)
Always guarantees a usable result, even if external services fail.
"""
# Always run the fast, local simple detection first
trait_patterns = {
"apologetic": ["sorry", "apologize", "apologies", "I apologize"],
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
@ -59,10 +58,10 @@ class PersonaMemory(BaseMemory):
"analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"],
}
trait_hits: Dict[str, int] = {}
trait_hits: dict[str, int] = {}
lowered_text = text.lower()
# Clean the text for word matching by removing punctuation
# Clean text for word matching
clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
words = clean_text.split()
@ -79,27 +78,23 @@ class PersonaMemory(BaseMemory):
else:
if pattern_lower in lowered_text:
count += 1
if count > 0:
trait_hits[trait] = count
# Now attempt external engine if available
if self.trait_detection_engine:
try:
external_traits = self.trait_detection_engine(text)
if isinstance(external_traits, dict):
return external_traits
except Exception:
pass # Silently fall back to default detection
pass # Fallback to default detection
# Fall back to simple default detection if external fails or is unavailable
return trait_hits
def load_memory_variables(
self, inputs: Dict[str, Any], include_messages: bool = False
) -> Dict[str, Any]:
self, inputs: dict[str, Any], include_messages: bool = False
) -> dict[str, Any]:
"""Return the stored persona traits and optionally recent conversation messages."""
memory_data = {"traits": self.persona_traits.copy()}
if include_messages:
@ -109,12 +104,9 @@ class PersonaMemory(BaseMemory):
return {self.memory_key: memory_data}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
"""Analyze outputs and update the persona traits and recent messages."""
input_text = inputs.get(self.input_key, "")
output_text = outputs.get(self.output_key, "")
traits_detected = self._detect_traits(output_text)
message = EnrichedMessage(
@ -126,11 +118,10 @@ class PersonaMemory(BaseMemory):
self.recent_messages.append(message)
# Trim messages to maintain memory size k
# Trim recent messages to maintain memory size k
if len(self.recent_messages) > self.k:
self.recent_messages = self.recent_messages[-self.k :]
# Update persona traits
self.persona_traits = list(
{trait for msg in self.recent_messages for trait in msg.traits}
)