mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
mend
This commit is contained in:
parent
a982d2f27d
commit
21c491a022
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -10,8 +10,8 @@ class EnrichedMessage(BaseModel):
|
|||||||
|
|
||||||
id: str
|
id: str
|
||||||
content: str
|
content: str
|
||||||
traits: List[str] = Field(default_factory=list)
|
traits: list[str] = Field(default_factory=list)
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class PersonaMemory(BaseMemory):
|
class PersonaMemory(BaseMemory):
|
||||||
@ -28,12 +28,12 @@ class PersonaMemory(BaseMemory):
|
|||||||
input_key: str = "input"
|
input_key: str = "input"
|
||||||
output_key: str = "output"
|
output_key: str = "output"
|
||||||
k: int = 10
|
k: int = 10
|
||||||
trait_detection_engine: Optional[Callable[[str], Dict[str, int]]] = None
|
trait_detection_engine: Optional[Callable[[str], dict[str, int]]] = None
|
||||||
persona_traits: List[str] = Field(default_factory=list)
|
persona_traits: list[str] = Field(default_factory=list)
|
||||||
recent_messages: List[Any] = Field(default_factory=list)
|
recent_messages: list[Any] = Field(default_factory=list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def memory_variables(self) -> List[str]:
|
def memory_variables(self) -> list[str]:
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
@ -41,13 +41,12 @@ class PersonaMemory(BaseMemory):
|
|||||||
self.persona_traits = []
|
self.persona_traits = []
|
||||||
self.recent_messages = []
|
self.recent_messages = []
|
||||||
|
|
||||||
def _detect_traits(self, text: str) -> Dict[str, int]:
|
def _detect_traits(self, text: str) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Detect persona traits using both a default method and optionally an external engine.
|
Detect persona traits using both a default method and optionally an external engine.
|
||||||
|
|
||||||
Always guarantees a usable result, even if external services fail. (Just lower quality as hard-coded)
|
Always guarantees a usable result, even if external services fail.
|
||||||
"""
|
"""
|
||||||
# Always run the fast, local simple detection first
|
|
||||||
trait_patterns = {
|
trait_patterns = {
|
||||||
"apologetic": ["sorry", "apologize", "apologies", "I apologize"],
|
"apologetic": ["sorry", "apologize", "apologies", "I apologize"],
|
||||||
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
|
"enthusiastic": ["!", "awesome", "great job", "fantastic"],
|
||||||
@ -59,10 +58,10 @@ class PersonaMemory(BaseMemory):
|
|||||||
"analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"],
|
"analytical": ["analytical", "analyze", "analysis", "logical", "reasoning"],
|
||||||
}
|
}
|
||||||
|
|
||||||
trait_hits: Dict[str, int] = {}
|
trait_hits: dict[str, int] = {}
|
||||||
lowered_text = text.lower()
|
lowered_text = text.lower()
|
||||||
|
|
||||||
# Clean the text for word matching by removing punctuation
|
# Clean text for word matching
|
||||||
clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
|
clean_text = re.sub(r"[^\w\s]", " ", lowered_text)
|
||||||
words = clean_text.split()
|
words = clean_text.split()
|
||||||
|
|
||||||
@ -79,27 +78,23 @@ class PersonaMemory(BaseMemory):
|
|||||||
else:
|
else:
|
||||||
if pattern_lower in lowered_text:
|
if pattern_lower in lowered_text:
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
trait_hits[trait] = count
|
trait_hits[trait] = count
|
||||||
|
|
||||||
# Now attempt external engine if available
|
|
||||||
if self.trait_detection_engine:
|
if self.trait_detection_engine:
|
||||||
try:
|
try:
|
||||||
external_traits = self.trait_detection_engine(text)
|
external_traits = self.trait_detection_engine(text)
|
||||||
if isinstance(external_traits, dict):
|
if isinstance(external_traits, dict):
|
||||||
return external_traits
|
return external_traits
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Silently fall back to default detection
|
pass # Fallback to default detection
|
||||||
|
|
||||||
# Fall back to simple default detection if external fails or is unavailable
|
|
||||||
return trait_hits
|
return trait_hits
|
||||||
|
|
||||||
def load_memory_variables(
|
def load_memory_variables(
|
||||||
self, inputs: Dict[str, Any], include_messages: bool = False
|
self, inputs: dict[str, Any], include_messages: bool = False
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return the stored persona traits and optionally recent conversation messages."""
|
"""Return the stored persona traits and optionally recent conversation messages."""
|
||||||
|
|
||||||
memory_data = {"traits": self.persona_traits.copy()}
|
memory_data = {"traits": self.persona_traits.copy()}
|
||||||
|
|
||||||
if include_messages:
|
if include_messages:
|
||||||
@ -109,12 +104,9 @@ class PersonaMemory(BaseMemory):
|
|||||||
|
|
||||||
return {self.memory_key: memory_data}
|
return {self.memory_key: memory_data}
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||||
"""Analyze outputs and update the persona traits and recent messages."""
|
"""Analyze outputs and update the persona traits and recent messages."""
|
||||||
|
|
||||||
input_text = inputs.get(self.input_key, "")
|
|
||||||
output_text = outputs.get(self.output_key, "")
|
output_text = outputs.get(self.output_key, "")
|
||||||
|
|
||||||
traits_detected = self._detect_traits(output_text)
|
traits_detected = self._detect_traits(output_text)
|
||||||
|
|
||||||
message = EnrichedMessage(
|
message = EnrichedMessage(
|
||||||
@ -126,11 +118,10 @@ class PersonaMemory(BaseMemory):
|
|||||||
|
|
||||||
self.recent_messages.append(message)
|
self.recent_messages.append(message)
|
||||||
|
|
||||||
# Trim messages to maintain memory size k
|
# Trim recent messages to maintain memory size k
|
||||||
if len(self.recent_messages) > self.k:
|
if len(self.recent_messages) > self.k:
|
||||||
self.recent_messages = self.recent_messages[-self.k :]
|
self.recent_messages = self.recent_messages[-self.k :]
|
||||||
|
|
||||||
# Update persona traits
|
|
||||||
self.persona_traits = list(
|
self.persona_traits = list(
|
||||||
{trait for msg in self.recent_messages for trait in msg.traits}
|
{trait for msg in self.recent_messages for trait in msg.traits}
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user