Compare commits

...

3 Commits

Author SHA1 Message Date
Eugene Yurtsev
382cff67fb Merge branch 'master' into eugene/bug_history 2023-12-08 13:15:35 -05:00
Eugene Yurtsev
a7be0aa290 x 2023-12-08 13:14:46 -05:00
Eugene Yurtsev
d3379c58be x 2023-12-08 12:45:41 -05:00

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel
@@ -17,8 +17,10 @@ def test_interfaces() -> None:
assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2"
def _get_get_session_history() -> Callable[..., ChatMessageHistory]:
chat_history_store = {}
def _get_get_session_history(
*, store: Optional[Dict[str, Any]] = None
) -> Callable[..., ChatMessageHistory]:
chat_history_store = store if store is not None else {}
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
if session_id not in chat_history_store:
@@ -33,13 +35,24 @@ def test_input_messages() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
get_session_history = _get_get_session_history()
store = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {"configurable": {"session_id": "1"}}
output = with_history.invoke([HumanMessage(content="hello")], config)
assert output == "you said: hello"
output = with_history.invoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye"
assert store == {
"1": ChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
HumanMessage(content="good bye"),
AIMessage(content="you said: hello\ngood bye"),
]
)
}
def test_input_dict() -> None: