From d6a032d8e7c111d69f4b20bdf47154330fc16b61 Mon Sep 17 00:00:00 2001 From: Mikolaj Mikuliszyn Date: Sun, 27 Apr 2025 17:48:12 +0100 Subject: [PATCH] mend --- libs/langchain/langchain/memory/persona.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/memory/persona.py b/libs/langchain/langchain/memory/persona.py index 114c49ef56c..9986ad7a572 100644 --- a/libs/langchain/langchain/memory/persona.py +++ b/libs/langchain/langchain/memory/persona.py @@ -1,12 +1,14 @@ -from typing import Any, Callable, Optional -from pydantic import BaseModel, Field import re +from typing import Any, Callable, Optional +from pydantic import BaseModel, Field from langchain_core.memory import BaseMemory class EnrichedMessage(BaseModel): - """A message enriched with persona traits and metadata.""" + """ + A message enriched with persona traits and metadata. + """ id: str content: str @@ -43,7 +45,8 @@ class PersonaMemory(BaseMemory): def _detect_traits(self, text: str) -> dict[str, int]: """ - Detect persona traits using both a default method and optionally an external engine. + Detect persona traits using both a default method and optionally + an external engine. Always guarantees a usable result, even if external services fail. """ @@ -94,7 +97,10 @@ class PersonaMemory(BaseMemory): def load_memory_variables( self, inputs: dict[str, Any], include_messages: bool = False ) -> dict[str, Any]: - """Return the stored persona traits and optionally recent conversation messages.""" + """ + Return the stored persona traits and optionally recent + conversation messages. + """ memory_data = {"traits": self.persona_traits.copy()} if include_messages: @@ -105,7 +111,9 @@ class PersonaMemory(BaseMemory): return {self.memory_key: memory_data} def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None: - """Analyze outputs and update the persona traits and recent messages.""" + """ + Analyze outputs and update the persona traits and recent messages. + """ output_text = outputs.get(self.output_key, "") traits_detected = self._detect_traits(output_text)