Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
cc262b9d01 rfc 2023-08-24 16:23:46 -07:00
2 changed files with 38 additions and 2 deletions

View File

@@ -321,7 +321,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@@ -1,10 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple
from langchain.load.serializable import Serializable
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
Runnable,
)
class BaseMemory(Serializable, ABC):
@@ -59,6 +64,37 @@ class BaseMemory(Serializable, ABC):
def clear(self) -> None:
"""Clear memory contents."""
def as_runnables(self, input_key: str = "input") -> Tuple[Runnable, Runnable]:
return self.as_load_runnable(input_key), self.as_save_runnable(input_key)
def _load_memory_variables_return_inputs(
self, inputs: Dict[str, Any]
) -> Dict[str, Any]:
mem_vars = self.load_memory_variables(inputs)
return {**inputs, **mem_vars}
def as_load_runnable(
self, input_key: str
) -> Runnable[Dict[str, Any], Dict[str, Any]]:
""""""
return PutLocalVar(input_key) | self._load_memory_variables_return_inputs
def as_save_runnable(
self, input_key: str
) -> Runnable[Dict[str, Any], Dict[str, Any]]:
""""""
def _runnable_save(input_output: Dict[str, Any]) -> Dict[str, Any]:
input = input_output[input_key]
if isinstance(input_output["output"], dict):
output = input_output["output"]
else:
output = {"output": input_output["output"]}
self.save_context(input, output)
return output
return GetLocalVar(input_key, passthrough_key="output") | _runnable_save
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.