From 21c491a02246d87866d05ee991eb2574059a681a Mon Sep 17 00:00:00 2001 From: Mikolaj Mikuliszyn Date: Sun, 27 Apr 2025 17:41:37 +0100 Subject: [PATCH] mend --- libs/langchain/langchain/memory/persona.py | 41 +++++++++------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/libs/langchain/langchain/memory/persona.py b/libs/langchain/langchain/memory/persona.py index bfc4fb613af..114c49ef56c 100644 --- a/libs/langchain/langchain/memory/persona.py +++ b/libs/langchain/langchain/memory/persona.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional from pydantic import BaseModel, Field import re @@ -10,8 +10,8 @@ class EnrichedMessage(BaseModel): id: str content: str - traits: List[str] = Field(default_factory=list) - metadata: Dict[str, Any] = Field(default_factory=dict) + traits: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) class PersonaMemory(BaseMemory): @@ -28,12 +28,12 @@ class PersonaMemory(BaseMemory): input_key: str = "input" output_key: str = "output" k: int = 10 - trait_detection_engine: Optional[Callable[[str], Dict[str, int]]] = None - persona_traits: List[str] = Field(default_factory=list) - recent_messages: List[Any] = Field(default_factory=list) + trait_detection_engine: Optional[Callable[[str], dict[str, int]]] = None + persona_traits: list[str] = Field(default_factory=list) + recent_messages: list[Any] = Field(default_factory=list) @property - def memory_variables(self) -> List[str]: + def memory_variables(self) -> list[str]: return [self.memory_key] def clear(self) -> None: @@ -41,13 +41,12 @@ class PersonaMemory(BaseMemory): self.persona_traits = [] self.recent_messages = [] - def _detect_traits(self, text: str) -> Dict[str, int]: + def _detect_traits(self, text: str) -> dict[str, int]: """ Detect persona traits using both a default method and optionally an external engine. - Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded) + Always guarantees a usable result, even if external services fail. """ - # Always run the fast, local simple detection first trait_patterns = { "apologetic": ["sorry", "apologize", "apologies", "I apologize"], "enthusiastic": ["!", "awesome", "great job", "fantastic"], @@ -59,10 +58,10 @@ class PersonaMemory(BaseMemory): "analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"], } - trait_hits: Dict[str, int] = {} + trait_hits: dict[str, int] = {} lowered_text = text.lower() - # Clean the text for word matching by removing punctuation + # Clean text for word matching clean_text = re.sub(r"[^\w\s]", " ", lowered_text) words = clean_text.split() @@ -79,27 +78,23 @@ class PersonaMemory(BaseMemory): else: if pattern_lower in lowered_text: count += 1 - if count > 0: trait_hits[trait] = count - # Now attempt external engine if available if self.trait_detection_engine: try: external_traits = self.trait_detection_engine(text) if isinstance(external_traits, dict): return external_traits except Exception: - pass # Silently fall back to default detection + pass # Fallback to default detection - # Fall back to simple default detection if external fails or is unavailable return trait_hits def load_memory_variables( - self, inputs: Dict[str, Any], include_messages: bool = False - ) -> Dict[str, Any]: + self, inputs: dict[str, Any], include_messages: bool = False + ) -> dict[str, Any]: """Return the stored persona traits and optionally recent conversation messages.""" - memory_data = {"traits": self.persona_traits.copy()} if include_messages: @@ -109,12 +104,9 @@ class PersonaMemory(BaseMemory): return {self.memory_key: memory_data} - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None: """Analyze outputs and update the persona traits and recent messages.""" - - input_text = inputs.get(self.input_key, "") output_text = outputs.get(self.output_key, "") - traits_detected = self._detect_traits(output_text) message = EnrichedMessage( @@ -126,11 +118,10 @@ class PersonaMemory(BaseMemory): self.recent_messages.append(message) - # Trim messages to maintain memory size k + # Trim recent messages to maintain memory size k if len(self.recent_messages) > self.k: self.recent_messages = self.recent_messages[-self.k :] - # Update persona traits self.persona_traits = list( {trait for msg in self.recent_messages for trait in msg.traits} )