mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 16:39:20 +00:00
community[minor]: Add native async support to SQLChatMessageHistory (#22065)
# package community: Fix SQLChatMessageHistory ## Description Here is a rewrite of `SQLChatMessageHistory` to properly implement the asynchronous approach. The code circumvents [issue 22021](https://github.com/langchain-ai/langchain/issues/22021) by accepting a synchronous call to `def add_messages()` in an asynchronous scenario. This bypasses the bug. For the same reasons as in [PR 22](https://github.com/langchain-ai/langchain-postgres/pull/32) of `langchain-postgres`, we use a lazy strategy for table creation. Indeed, the promise of the constructor cannot be fulfilled without this. It is not possible to invoke a synchronous call in a constructor. We compensate for this by waiting for the next asynchronous method call to create the table. The goal of the `PostgresChatMessageHistory` class (in `langchain-postgres`) is, among other things, to be able to recycle database connections. The implementation of the class is problematic, as we have demonstrated in [issue 22021](https://github.com/langchain-ai/langchain/issues/22021). Our new implementation of `SQLChatMessageHistory` achieves this by using a singleton of type (`Async`)`Engine` for the database connection. The connection pool is managed by this singleton, and the code is then reentrant. We also accept the type `str` (optionally complemented by `async_mode`. I know you don't like this much, but it's the only way to allow an asynchronous connection string). In order to unify the different classes handling database connections, we have renamed `connection_string` to `connection`, and `Session` to `session_maker`. Now, a single transaction is used to add a list of messages. Thus, a crash during this write operation will not leave the database in an unstable state with a partially added message list. This makes the code resilient. We believe that the `PostgresChatMessageHistory` class is no longer necessary and can be replaced by: ``` PostgresChatMessageHistory = SQLChatMessageHistory ``` This also fixes the bug. ## Issue - [issue 22021](https://github.com/langchain-ai/langchain/issues/22021) - Bug in _exit_history() - Bugs in PostgresChatMessageHistory and sync usage - Bugs in PostgresChatMessageHistory and async usage - [issue 36](https://github.com/langchain-ai/langchain-postgres/issues/36) ## Twitter handle: pprados ## Tests - libs/community/tests/unit_tests/chat_message_histories/test_sql.py (add async test) @baskaryan, @eyurtsev or @hwchase17 can you check this PR ? And, I've been waiting a long time for validation from other PRs. Can you take a look? - [PR 32](https://github.com/langchain-ai/langchain-postgres/pull/32) - [PR 15575](https://github.com/langchain-ai/langchain/pull/15575) - [PR 13200](https://github.com/langchain-ai/langchain/pull/13200) --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
59bef31997
commit
8250c177de
@ -1,9 +1,22 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, create_engine
|
||||
from langchain_core._api import deprecated, warn_deprecated
|
||||
from sqlalchemy import Column, Integer, Text, delete, select
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
@ -15,7 +28,22 @@ from langchain_core.messages import (
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import (
|
||||
Session as SQLSession,
|
||||
)
|
||||
from sqlalchemy.orm import (
|
||||
declarative_base,
|
||||
scoped_session,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -80,36 +108,98 @@ class DefaultMessageConverter(BaseMessageConverter):
|
||||
return self.model_class
|
||||
|
||||
|
||||
DBConnection = Union[AsyncEngine, Engine, str]
|
||||
|
||||
_warned_once_already = False
|
||||
|
||||
|
||||
class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in an SQL database."""
|
||||
|
||||
@property
|
||||
@deprecated("0.2.2", removal="0.3.0", alternative="session_maker")
|
||||
def Session(self) -> Union[scoped_session, async_sessionmaker]:
|
||||
return self.session_maker
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_string: str,
|
||||
connection_string: Optional[str] = None,
|
||||
table_name: str = "message_store",
|
||||
session_id_field_name: str = "session_id",
|
||||
custom_message_converter: Optional[BaseMessageConverter] = None,
|
||||
connection: Union[None, DBConnection] = None,
|
||||
engine_args: Optional[Dict[str, Any]] = None,
|
||||
async_mode: Optional[bool] = None, # Use only if connection is a string
|
||||
):
|
||||
self.connection_string = connection_string
|
||||
self.engine = create_engine(connection_string, echo=False)
|
||||
assert not (
|
||||
connection_string and connection
|
||||
), "connection_string and connection are mutually exclusive"
|
||||
if connection_string:
|
||||
global _warned_once_already
|
||||
if not _warned_once_already:
|
||||
warn_deprecated(
|
||||
since="0.2.2",
|
||||
removal="0.3.0",
|
||||
name="connection_string",
|
||||
alternative="Use connection instead",
|
||||
)
|
||||
_warned_once_already = True
|
||||
connection = connection_string
|
||||
self.connection_string = connection_string
|
||||
if isinstance(connection, str):
|
||||
self.async_mode = async_mode
|
||||
if async_mode:
|
||||
self.async_engine = create_async_engine(
|
||||
connection, **(engine_args or {})
|
||||
)
|
||||
else:
|
||||
self.engine = create_engine(url=connection, **(engine_args or {}))
|
||||
elif isinstance(connection, Engine):
|
||||
self.async_mode = False
|
||||
self.engine = connection
|
||||
elif isinstance(connection, AsyncEngine):
|
||||
self.async_mode = True
|
||||
self.async_engine = connection
|
||||
else:
|
||||
raise ValueError(
|
||||
"connection should be a connection string or an instance of "
|
||||
"sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine"
|
||||
)
|
||||
|
||||
# To be consistent with others SQL implementations, rename to session_maker
|
||||
self.session_maker: Union[scoped_session, async_sessionmaker]
|
||||
if self.async_mode:
|
||||
self.session_maker = async_sessionmaker(bind=self.async_engine)
|
||||
else:
|
||||
self.session_maker = scoped_session(sessionmaker(bind=self.engine))
|
||||
|
||||
self.session_id_field_name = session_id_field_name
|
||||
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
|
||||
self.sql_model_class = self.converter.get_sql_model_class()
|
||||
if not hasattr(self.sql_model_class, session_id_field_name):
|
||||
raise ValueError("SQL model class must have session_id column")
|
||||
self._create_table_if_not_exists()
|
||||
self._table_created = False
|
||||
if not self.async_mode:
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
self.session_id = session_id
|
||||
self.Session = sessionmaker(self.engine)
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
self.sql_model_class.metadata.create_all(self.engine)
|
||||
self._table_created = True
|
||||
|
||||
async def _acreate_table_if_not_exists(self) -> None:
|
||||
if not self._table_created:
|
||||
assert self.async_mode, "This method must be called with async_mode"
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(self.sql_model_class.metadata.create_all)
|
||||
self._table_created = True
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve all messages from db"""
|
||||
with self.Session() as session:
|
||||
with self._make_sync_session() as session:
|
||||
result = (
|
||||
session.query(self.sql_model_class)
|
||||
.where(
|
||||
@ -123,18 +213,105 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
messages.append(self.converter.from_sql_model(record))
|
||||
return messages
|
||||
|
||||
def get_messages(self) -> List[BaseMessage]:
|
||||
return self.messages
|
||||
|
||||
async def aget_messages(self) -> List[BaseMessage]:
|
||||
"""Retrieve all messages from db"""
|
||||
await self._acreate_table_if_not_exists()
|
||||
async with self._make_async_session() as session:
|
||||
stmt = (
|
||||
select(self.sql_model_class)
|
||||
.where(
|
||||
getattr(self.sql_model_class, self.session_id_field_name)
|
||||
== self.session_id
|
||||
)
|
||||
.order_by(self.sql_model_class.id.asc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
messages = []
|
||||
for record in result.scalars():
|
||||
messages.append(self.converter.from_sql_model(record))
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in db"""
|
||||
with self.Session() as session:
|
||||
with self._make_sync_session() as session:
|
||||
session.add(self.converter.to_sql_model(message, self.session_id))
|
||||
session.commit()
|
||||
|
||||
async def aadd_message(self, message: BaseMessage) -> None:
|
||||
"""Add a Message object to the store.
|
||||
|
||||
Args:
|
||||
message: A BaseMessage object to store.
|
||||
"""
|
||||
await self._acreate_table_if_not_exists()
|
||||
async with self._make_async_session() as session:
|
||||
session.add(self.converter.to_sql_model(message, self.session_id))
|
||||
await session.commit()
|
||||
|
||||
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
# The method RunnableWithMessageHistory._exit_history() call
|
||||
# add_message method by mistake and not aadd_message.
|
||||
# See https://github.com/langchain-ai/langchain/issues/22021
|
||||
if self.async_mode:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.aadd_messages(messages))
|
||||
else:
|
||||
with self._make_sync_session() as session:
|
||||
for message in messages:
|
||||
session.add(self.converter.to_sql_model(message, self.session_id))
|
||||
session.commit()
|
||||
|
||||
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
# Add all messages in one transaction
|
||||
await self._acreate_table_if_not_exists()
|
||||
async with self.session_maker() as session:
|
||||
for message in messages:
|
||||
session.add(self.converter.to_sql_model(message, self.session_id))
|
||||
await session.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from db"""
|
||||
|
||||
with self.Session() as session:
|
||||
with self._make_sync_session() as session:
|
||||
session.query(self.sql_model_class).filter(
|
||||
getattr(self.sql_model_class, self.session_id_field_name)
|
||||
== self.session_id
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear session memory from db"""
|
||||
|
||||
await self._acreate_table_if_not_exists()
|
||||
async with self._make_async_session() as session:
|
||||
stmt = delete(self.sql_model_class).filter(
|
||||
getattr(self.sql_model_class, self.session_id_field_name)
|
||||
== self.session_id
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _make_sync_session(self) -> Generator[SQLSession, None, None]:
|
||||
"""Make an async session."""
|
||||
if self.async_mode:
|
||||
raise ValueError(
|
||||
"Attempting to use a sync method in when async mode is turned on. "
|
||||
"Please use the corresponding async method instead."
|
||||
)
|
||||
with self.session_maker() as session:
|
||||
yield cast(SQLSession, session)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Make an async session."""
|
||||
if not self.async_mode:
|
||||
raise ValueError(
|
||||
"Attempting to use an async method in when sync mode is turned on. "
|
||||
"Please use the corresponding async method instead."
|
||||
)
|
||||
async with self.session_maker() as session:
|
||||
yield cast(AsyncSession, session)
|
||||
|
21
libs/community/poetry.lock
generated
21
libs/community/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aenum"
|
||||
@ -3475,6 +3475,7 @@ files = [
|
||||
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
|
||||
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
|
||||
{file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
|
||||
{file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
|
||||
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
|
||||
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
|
||||
{file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
|
||||
@ -3985,7 +3986,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -4026,7 +4027,7 @@ url = "../langchain"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.3"
|
||||
version = "0.2.4"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -4035,7 +4036,7 @@ develop = true
|
||||
|
||||
[package.dependencies]
|
||||
jsonpatch = "^1.33"
|
||||
langsmith = "^0.1.65"
|
||||
langsmith = "^0.1.66"
|
||||
packaging = "^23.2"
|
||||
pydantic = ">=1,<3"
|
||||
PyYAML = ">=5.3"
|
||||
@ -4050,7 +4051,7 @@ url = "../core"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-text-splitters"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
description = "LangChain text splitting utilities"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -6123,8 +6124,6 @@ files = [
|
||||
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
|
||||
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
|
||||
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
|
||||
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
|
||||
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
|
||||
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
|
||||
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
|
||||
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
|
||||
@ -6167,7 +6166,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
|
||||
@ -6176,8 +6174,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
|
||||
@ -7175,7 +7171,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@ -10217,9 +10212,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
|
||||
[extras]
|
||||
cli = ["typer"]
|
||||
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
|
||||
extended-testing = ["aiosqlite", "aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "azure-identity", "azure-search-documents", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpathlib", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "oracledb", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "simsimd", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "77ccc0105fabe1735497289125bb276822101a6a9b1c2b596bf49b8f30b8068d"
|
||||
content-hash = "22bdadbd8a34235ba0cd923d9b380d362caa64f000053a5f91f9d163e8b41aad"
|
||||
|
@ -291,6 +291,7 @@ extended_testing = [
|
||||
"pyjwt",
|
||||
"oracledb",
|
||||
"simsimd",
|
||||
"aiosqlite"
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
@ -1,8 +1,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Tuple
|
||||
from typing import Any, AsyncGenerator, Generator, List, Tuple
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from sqlalchemy import Column, Integer, Text
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
@ -17,16 +17,23 @@ def con_str(tmp_path: Path) -> str:
|
||||
return con_str
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def acon_str(tmp_path: Path) -> str:
|
||||
file_path = tmp_path / "adb.sqlite3"
|
||||
con_str = f"sqlite+aiosqlite:///{file_path}"
|
||||
return con_str
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sql_histories(
|
||||
con_str: str,
|
||||
) -> Generator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None, None]:
|
||||
message_history = SQLChatMessageHistory(
|
||||
session_id="123", connection_string=con_str, table_name="test_table"
|
||||
session_id="123", connection=con_str, table_name="test_table"
|
||||
)
|
||||
# Create history for other session
|
||||
other_history = SQLChatMessageHistory(
|
||||
session_id="456", connection_string=con_str, table_name="test_table"
|
||||
session_id="456", connection=con_str, table_name="test_table"
|
||||
)
|
||||
|
||||
yield message_history, other_history
|
||||
@ -34,12 +41,38 @@ def sql_histories(
|
||||
other_history.clear()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def asql_histories(
|
||||
acon_str: str,
|
||||
) -> AsyncGenerator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None]:
|
||||
message_history = SQLChatMessageHistory(
|
||||
session_id="123",
|
||||
connection=acon_str,
|
||||
table_name="test_table",
|
||||
async_mode=True,
|
||||
engine_args={"echo": False},
|
||||
)
|
||||
# Create history for other session
|
||||
other_history = SQLChatMessageHistory(
|
||||
session_id="456",
|
||||
connection=acon_str,
|
||||
table_name="test_table",
|
||||
async_mode=True,
|
||||
engine_args={"echo": False},
|
||||
)
|
||||
|
||||
yield message_history, other_history
|
||||
await message_history.aclear()
|
||||
await other_history.aclear()
|
||||
|
||||
|
||||
def test_add_messages(
|
||||
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
|
||||
) -> None:
|
||||
sql_history, other_history = sql_histories
|
||||
sql_history.add_user_message("Hello!")
|
||||
sql_history.add_ai_message("Hi there!")
|
||||
sql_history.add_messages(
|
||||
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
|
||||
)
|
||||
|
||||
messages = sql_history.messages
|
||||
assert len(messages) == 2
|
||||
@ -49,39 +82,94 @@ def test_add_messages(
|
||||
assert messages[1].content == "Hi there!"
|
||||
|
||||
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_async_add_messages(
|
||||
asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
|
||||
) -> None:
|
||||
sql_history, other_history = asql_histories
|
||||
await sql_history.aadd_messages(
|
||||
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
|
||||
)
|
||||
|
||||
messages = await sql_history.aget_messages()
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert isinstance(messages[1], AIMessage)
|
||||
assert messages[0].content == "Hello!"
|
||||
assert messages[1].content == "Hi there!"
|
||||
|
||||
|
||||
def test_multiple_sessions(
|
||||
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
|
||||
) -> None:
|
||||
sql_history, other_history = sql_histories
|
||||
sql_history.add_user_message("Hello!")
|
||||
sql_history.add_ai_message("Hi there!")
|
||||
sql_history.add_user_message("Whats cracking?")
|
||||
sql_history.add_messages(
|
||||
[
|
||||
HumanMessage(content="Hello!"),
|
||||
AIMessage(content="Hi there!"),
|
||||
HumanMessage(content="Whats cracking?"),
|
||||
]
|
||||
)
|
||||
|
||||
# Ensure the messages are added correctly in the first session
|
||||
assert len(sql_history.messages) == 3, "waat"
|
||||
assert sql_history.messages[0].content == "Hello!"
|
||||
assert sql_history.messages[1].content == "Hi there!"
|
||||
assert sql_history.messages[2].content == "Whats cracking?"
|
||||
messages = sql_history.messages
|
||||
assert len(messages) == 3, "waat"
|
||||
assert messages[0].content == "Hello!"
|
||||
assert messages[1].content == "Hi there!"
|
||||
assert messages[2].content == "Whats cracking?"
|
||||
|
||||
# second session
|
||||
other_history.add_user_message("Hellox")
|
||||
other_history.add_messages([HumanMessage(content="Hellox")])
|
||||
assert len(other_history.messages) == 1
|
||||
assert len(sql_history.messages) == 3
|
||||
messages = sql_history.messages
|
||||
assert len(messages) == 3
|
||||
assert other_history.messages[0].content == "Hellox"
|
||||
assert sql_history.messages[0].content == "Hello!"
|
||||
assert sql_history.messages[1].content == "Hi there!"
|
||||
assert sql_history.messages[2].content == "Whats cracking?"
|
||||
assert messages[0].content == "Hello!"
|
||||
assert messages[1].content == "Hi there!"
|
||||
assert messages[2].content == "Whats cracking?"
|
||||
|
||||
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_async_multiple_sessions(
|
||||
asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
|
||||
) -> None:
|
||||
sql_history, other_history = asql_histories
|
||||
await sql_history.aadd_messages(
|
||||
[
|
||||
HumanMessage(content="Hello!"),
|
||||
AIMessage(content="Hi there!"),
|
||||
HumanMessage(content="Whats cracking?"),
|
||||
]
|
||||
)
|
||||
|
||||
# Ensure the messages are added correctly in the first session
|
||||
messages: List[BaseMessage] = await sql_history.aget_messages()
|
||||
assert len(messages) == 3, "waat"
|
||||
assert messages[0].content == "Hello!"
|
||||
assert messages[1].content == "Hi there!"
|
||||
assert messages[2].content == "Whats cracking?"
|
||||
|
||||
# second session
|
||||
await other_history.aadd_messages([HumanMessage(content="Hellox")])
|
||||
messages = await sql_history.aget_messages()
|
||||
assert len(await other_history.aget_messages()) == 1
|
||||
assert len(messages) == 3
|
||||
assert (await other_history.aget_messages())[0].content == "Hellox"
|
||||
assert messages[0].content == "Hello!"
|
||||
assert messages[1].content == "Hi there!"
|
||||
assert messages[2].content == "Whats cracking?"
|
||||
|
||||
|
||||
def test_clear_messages(
|
||||
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
|
||||
) -> None:
|
||||
sql_history, other_history = sql_histories
|
||||
sql_history.add_user_message("Hello!")
|
||||
sql_history.add_ai_message("Hi there!")
|
||||
sql_history.add_messages(
|
||||
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
|
||||
)
|
||||
assert len(sql_history.messages) == 2
|
||||
# Now create another history with different session id
|
||||
other_history.add_user_message("Hellox")
|
||||
other_history.add_messages([HumanMessage(content="Hellox")])
|
||||
assert len(other_history.messages) == 1
|
||||
assert len(sql_history.messages) == 2
|
||||
# Now clear the first history
|
||||
@ -90,6 +178,25 @@ def test_clear_messages(
|
||||
assert len(other_history.messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_async_clear_messages(
|
||||
asql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory],
|
||||
) -> None:
|
||||
sql_history, other_history = asql_histories
|
||||
await sql_history.aadd_messages(
|
||||
[HumanMessage(content="Hello!"), AIMessage(content="Hi there!")]
|
||||
)
|
||||
assert len(await sql_history.aget_messages()) == 2
|
||||
# Now create another history with different session id
|
||||
await other_history.aadd_messages([HumanMessage(content="Hellox")])
|
||||
assert len(await other_history.aget_messages()) == 1
|
||||
assert len(await sql_history.aget_messages()) == 2
|
||||
# Now clear the first history
|
||||
await sql_history.aclear()
|
||||
assert len(await sql_history.aget_messages()) == 0
|
||||
assert len(await other_history.aget_messages()) == 1
|
||||
|
||||
|
||||
def test_model_no_session_id_field_error(con_str: str) -> None:
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user