mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
mend
This commit is contained in:
parent
a982d2f27d
commit
21c491a022
@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import re
|
||||
|
||||
@ -10,8 +10,8 @@ class EnrichedMessage(BaseModel):
|
||||
|
||||
id: str
|
||||
content: str
|
||||
traits: List[str] = Field(default_factory=list)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
traits: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class PersonaMemory(BaseMemory):
|
||||
@ -28,12 +28,12 @@ class PersonaMemory(BaseMemory):
|
||||
input_key: str = "input"
|
||||
output_key: str = "output"
|
||||
k: int = 10
|
||||
trait_detection_engine: Optional[Callable[[str], Dict[str, int]]] = None
|
||||
persona_traits: List[str] = Field(default_factory=list)
|
||||
recent_messages: List[Any] = Field(default_factory=list)
|
||||
trait_detection_engine: Optional[Callable[[str], dict[str, int]]] = None
|
||||
persona_traits: list[str] = Field(default_factory=list)
|
||||
recent_messages: list[Any] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
def memory_variables(self) -> list[str]:
|
||||
return [self.memory_key]
|
||||
|
||||
def clear(self) -> None:
|
||||
@ -41,13 +41,12 @@ class PersonaMemory(BaseMemory):
|
||||
self.persona_traits = []
|
||||
self.recent_messages = []
|
||||
|
||||
def _detect_traits(self, text: str) -> Dict[str, int]:
|
||||
def _detect_traits(self, text: str) -> dict[str, int]:
|
||||
"""
|
||||
Detect persona traits using both a default method and optionally an external engine.
|
||||
|
||||
Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded)
|
||||
Always guarantees a usable result, even if external services fail.
|
||||
"""
|
||||
# Always run the fast, local simple detection first
|
||||
trait_patterns = {
|
||||
"apologetic": ["sorry", "apologize", "apologies", "I apologize"],
|
||||
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
|
||||
@ -59,10 +58,10 @@ class PersonaMemory(BaseMemory):
|
||||
"analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"],
|
||||
}
|
||||
|
||||
trait_hits: Dict[str, int] = {}
|
||||
trait_hits: dict[str, int] = {}
|
||||
lowered_text = text.lower()
|
||||
|
||||
# Clean the text for word matching by removing punctuation
|
||||
# Clean text for word matching
|
||||
clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
|
||||
words = clean_text.split()
|
||||
|
||||
@ -79,27 +78,23 @@ class PersonaMemory(BaseMemory):
|
||||
else:
|
||||
if pattern_lower in lowered_text:
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
trait_hits[trait] = count
|
||||
|
||||
# Now attempt external engine if available
|
||||
if self.trait_detection_engine:
|
||||
try:
|
||||
external_traits = self.trait_detection_engine(text)
|
||||
if isinstance(external_traits, dict):
|
||||
return external_traits
|
||||
except Exception:
|
||||
pass # Silently fall back to default detection
|
||||
pass # Fallback to default detection
|
||||
|
||||
# Fall back to simple default detection if external fails or is unavailable
|
||||
return trait_hits
|
||||
|
||||
def load_memory_variables(
|
||||
self, inputs: Dict[str, Any], include_messages: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
self, inputs: dict[str, Any], include_messages: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""Return the stored persona traits and optionally recent conversation messages."""
|
||||
|
||||
memory_data = {"traits": self.persona_traits.copy()}
|
||||
|
||||
if include_messages:
|
||||
@ -109,12 +104,9 @@ class PersonaMemory(BaseMemory):
|
||||
|
||||
return {self.memory_key: memory_data}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Analyze outputs and update the persona traits and recent messages."""
|
||||
|
||||
input_text = inputs.get(self.input_key, "")
|
||||
output_text = outputs.get(self.output_key, "")
|
||||
|
||||
traits_detected = self._detect_traits(output_text)
|
||||
|
||||
message = EnrichedMessage(
|
||||
@ -126,11 +118,10 @@ class PersonaMemory(BaseMemory):
|
||||
|
||||
self.recent_messages.append(message)
|
||||
|
||||
# Trim messages to maintain memory size k
|
||||
# Trim recent messages to maintain memory size k
|
||||
if len(self.recent_messages) > self.k:
|
||||
self.recent_messages = self.recent_messages[-self.k :]
|
||||
|
||||
# Update persona traits
|
||||
self.persona_traits = list(
|
||||
{trait for msg in self.recent_messages for trait in msg.traits}
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user