Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
5d62621b99 rfc 2023-08-28 13:47:47 -07:00
2 changed files with 110 additions and 4 deletions

View File

@@ -1,10 +1,20 @@
from __future__ import annotations
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain.load.serializable import Serializable
from langchain.schema.storage import BaseStore
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
Runnable,
RunnableConfig,
RunnableLambda,
RunnablePassthrough,
)
class BaseMemory(Serializable, ABC):
@@ -60,6 +70,97 @@ class BaseMemory(Serializable, ABC):
"""Clear memory contents."""
class BaseMemorySessionManager(Runnable[str, BaseMemory]):
""""""
def __init__(
self,
store: BaseStore[str, BaseMemory],
*,
default: Optional[BaseMemory] = None,
default_factory: Optional[Callable[[], BaseMemory]] = None,
) -> None:
self._store = store
self._default = default
self._default_factory = default_factory
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> BaseMemory:
return self.mset_default((input))[0]
def batch(
self,
inputs: List[str],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[BaseMemory]:
return self.mset_default(inputs)
def mset_default(self, keys: Sequence[str]) -> List[BaseMemory]:
vals = self._store.mget(keys)
missing_idxs = [i for i, v in enumerate(vals) if v is None]
for i in missing_idxs:
vals[i] = self.default_factory()
self._store.mset(((keys[i], vals[i]) for i in missing_idxs))
return vals
def default_factory(self) -> BaseMemory:
return self._default or self._default_factory()
def _get_session_id_default(input: Any) -> str:
if isinstance(input, dict) and "session_id" in input:
return input["session_id"]
return str(uuid.uuid4())
def _pop_session_id(input: Any) -> str:
if isinstance(input, dict) and "session_id" in input:
input.pop("session_id")
return input
def _get_session_id(input: Any) -> str:
return input["session_id"]
def mem_loader(session_manager: "BaseMemorySessionManager", input_key: str) -> Runnable:
def _load_fn(x: Dict) -> Dict[str, Any]:
inputs = x[input_key]
mem = x["memory"]
return {**mem.load_memory_variables(inputs), **inputs}
return (
{"session_id": _get_session_id_default, input_key: _pop_session_id}
| PutLocalVar(("session_id", input_key))
| _get_session_id
| session_manager
| GetLocalVar(input_key, passthrough_key="memory")
| _load_fn
)
def mem_saver(session_manager: "BaseMemorySessionManager", input_key: str) -> Runnable:
def _runnable_save(x: Dict[str, Any]) -> Dict[str, Any]:
mem = x["memory"]
input = x[input_key]
if isinstance(x["output"], dict):
output = x["output"]
else:
output = {"output": x["output"]}
mem.save_context(input, output)
return x["output"]
return (
{
input_key: GetLocalVar(input_key),
"memory": GetLocalVar("session_id") | session_manager,
"output": RunnablePassthrough(),
}
| RunnableLambda(_runnable_save)
| {"output": RunnablePassthrough(), "session_id": GetLocalVar("session_id")}
)
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.

View File

@@ -8,6 +8,7 @@ from typing import (
Iterator,
Mapping,
Optional,
Sequence,
Union,
)
@@ -24,7 +25,7 @@ if TYPE_CHECKING:
class PutLocalVar(RunnablePassthrough):
key: Union[str, Mapping[str, str]]
key: Union[str, Sequence[str], Mapping[str, str]]
"""The key(s) to use for storing the input variable(s) in local state.
If a string is provided then the entire input is stored under that key. If a
@@ -52,13 +53,17 @@ class PutLocalVar(RunnablePassthrough):
config["_locals"][self.key] = input
else:
config["_locals"][self.key] += input
elif isinstance(self.key, Mapping):
elif isinstance(self.key, (Sequence, Mapping)):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
if not isinstance(self.key, Mapping):
key_map = {key: key for key in self.key}
else:
key_map = self.key
for input_key, put_key in key_map.items():
if put_key not in config["_locals"] or replace:
config["_locals"][put_key] = input[input_key]
else: