mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
postgres[minor]: add postgres checkpoint implementation (#20025)
Adds checkpoint implementation using psycopg
This commit is contained in:
parent
039b7a472d
commit
ba9e0d76c1
@ -63,3 +63,61 @@ chat_history.add_messages([
|
|||||||
|
|
||||||
print(chat_history.messages)
|
print(chat_history.messages)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### PostgresCheckpoint
|
||||||
|
|
||||||
|
An implementation of the `Checkpoint` abstraction in LangGraph using Postgres.
|
||||||
|
|
||||||
|
|
||||||
|
Async Usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
from langchain_postgres import (
|
||||||
|
PostgresCheckpoint, PickleCheckpointSerializer
|
||||||
|
)
|
||||||
|
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
# Example configuration
|
||||||
|
conninfo="postgresql://user:password@localhost:5432/dbname",
|
||||||
|
max_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Uses the pickle module for serialization
|
||||||
|
# Make sure that you're only de-serializing trusted data
|
||||||
|
# (e.g., payloads that you have serialized yourself).
|
||||||
|
# Or implement a custom serializer.
|
||||||
|
checkpoint = PostgresCheckpoint(
|
||||||
|
serializer=PickleCheckpointSerializer(),
|
||||||
|
async_connection=pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the checkpoint object to put, get, list checkpoints, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
Sync Usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from psycopg_pool import ConnectionPool
|
||||||
|
from langchain_postgres import (
|
||||||
|
PostgresCheckpoint, PickleCheckpointSerializer
|
||||||
|
)
|
||||||
|
|
||||||
|
pool = ConnectionPool(
|
||||||
|
# Example configuration
|
||||||
|
conninfo="postgresql://user:password@localhost:5432/dbname",
|
||||||
|
max_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Uses the pickle module for serialization
|
||||||
|
# Make sure that you're only de-serializing trusted data
|
||||||
|
# (e.g., payloads that you have serialized yourself).
|
||||||
|
# Or implement a custom serializer.
|
||||||
|
checkpoint = PostgresCheckpoint(
|
||||||
|
serializer=PickleCheckpointSerializer(),
|
||||||
|
sync_connection=pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the checkpoint object to put, get, list checkpoints, etc.
|
||||||
|
```
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
from importlib import metadata
|
from importlib import metadata
|
||||||
|
|
||||||
from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
|
from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
|
||||||
|
from langchain_postgres.checkpoint import (
|
||||||
|
CheckpointSerializer,
|
||||||
|
PickleCheckpointSerializer,
|
||||||
|
PostgresCheckpoint,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
__version__ = metadata.version(__package__)
|
__version__ = metadata.version(__package__)
|
||||||
@ -10,5 +15,8 @@ except metadata.PackageNotFoundError:
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"__version__",
|
"__version__",
|
||||||
|
"CheckpointSerializer",
|
||||||
"PostgresChatMessageHistory",
|
"PostgresChatMessageHistory",
|
||||||
|
"PostgresCheckpoint",
|
||||||
|
"PickleCheckpointSerializer",
|
||||||
]
|
]
|
||||||
|
565
libs/partners/postgres/langchain_postgres/checkpoint.py
Normal file
565
libs/partners/postgres/langchain_postgres/checkpoint.py
Normal file
@ -0,0 +1,565 @@
|
|||||||
|
"""Implementation of a langgraph checkpoint saver using Postgres."""
|
||||||
|
import abc
|
||||||
|
import pickle
|
||||||
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
|
from typing import AsyncGenerator, AsyncIterator, Generator, Optional, Union, cast
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
|
||||||
|
from langgraph.checkpoint import BaseCheckpointSaver
|
||||||
|
from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple
|
||||||
|
from psycopg_pool import AsyncConnectionPool, ConnectionPool
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointSerializer(abc.ABC):
|
||||||
|
"""A serializer for serializing and deserializing objects to and from bytes."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def dumps(self, obj: Checkpoint) -> bytes:
|
||||||
|
"""Serialize an object to bytes."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def loads(self, data: bytes) -> Checkpoint:
|
||||||
|
"""Deserialize an object from bytes."""
|
||||||
|
|
||||||
|
|
||||||
|
class PickleCheckpointSerializer(CheckpointSerializer):
|
||||||
|
"""Use the pickle module to serialize and deserialize objects.
|
||||||
|
|
||||||
|
This serializer uses the pickle module to serialize and deserialize objects.
|
||||||
|
|
||||||
|
While pickling can serialize a wide range of Python objects, it may fail
|
||||||
|
de-serializable objects upon updates of the Python version or the python
|
||||||
|
environment (e.g., the object's class definition changes in LangGraph).
|
||||||
|
|
||||||
|
*Security Warning*: The pickle module can deserialize malicious payloads,
|
||||||
|
only use this serializer with trusted data; e.g., data that you
|
||||||
|
have serialized yourself and can guarantee the integrity of.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def dumps(self, obj: Checkpoint) -> bytes:
|
||||||
|
"""Serialize an object to bytes."""
|
||||||
|
return pickle.dumps(obj)
|
||||||
|
|
||||||
|
def loads(self, data: bytes) -> Checkpoint:
|
||||||
|
"""Deserialize an object from bytes."""
|
||||||
|
return cast(Checkpoint, pickle.loads(data))
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresCheckpoint(BaseCheckpointSaver):
|
||||||
|
"""LangGraph checkpoint saver for Postgres.
|
||||||
|
|
||||||
|
This implementation of a checkpoint saver uses a Postgres database to save
|
||||||
|
and retrieve checkpoints. It uses the psycopg3 package to interact with the
|
||||||
|
Postgres database.
|
||||||
|
|
||||||
|
The checkpoint accepts either a sync_connection in the form of a psycopg.Connection
|
||||||
|
or a psycopg.ConnectionPool object, or an async_connection in the form of a
|
||||||
|
psycopg.AsyncConnection or psycopg.AsyncConnectionPool object.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
1. First time use: create schema in the database using the `create_schema` method or
|
||||||
|
the async version `acreate_schema` method.
|
||||||
|
2. Create a PostgresCheckpoint object with a serializer and an appropriate
|
||||||
|
connection object.
|
||||||
|
It's recommended to use a connection pool object for the connection.
|
||||||
|
If using a connection object, you are responsible for closing the connection
|
||||||
|
when done.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
|
||||||
|
Sync usage with a connection pool:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from psycopg_pool import ConnectionPool
|
||||||
|
from langchain_postgres import (
|
||||||
|
PostgresCheckpoint, PickleCheckpointSerializer
|
||||||
|
)
|
||||||
|
|
||||||
|
pool = ConnectionPool(
|
||||||
|
# Example configuration
|
||||||
|
conninfo="postgresql://user:password@localhost:5432/dbname",
|
||||||
|
max_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Uses the pickle module for serialization
|
||||||
|
# Make sure that you're only de-serializing trusted data
|
||||||
|
# (e.g., payloads that you have serialized yourself).
|
||||||
|
# Or implement a custom serializer.
|
||||||
|
checkpoint = PostgresCheckpoint(
|
||||||
|
serializer=PickleCheckpointSerializer(),
|
||||||
|
sync_connection=pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the checkpoint object to put, get, list checkpoints, etc.
|
||||||
|
|
||||||
|
|
||||||
|
Async usage with a connection pool:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
from langchain_postgres import (
|
||||||
|
PostgresCheckpoint, PickleCheckpointSerializer
|
||||||
|
)
|
||||||
|
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
# Example configuration
|
||||||
|
conninfo="postgresql://user:password@localhost:5432/dbname",
|
||||||
|
max_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Uses the pickle module for serialization
|
||||||
|
# Make sure that you're only de-serializing trusted data
|
||||||
|
# (e.g., payloads that you have serialized yourself).
|
||||||
|
# Or implement a custom serializer.
|
||||||
|
checkpoint = PostgresCheckpoint(
|
||||||
|
serializer=PickleCheckpointSerializer(),
|
||||||
|
async_connection=pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the checkpoint object to put, get, list checkpoints, etc.
|
||||||
|
|
||||||
|
|
||||||
|
Async usage with a connection object:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from psycopg import AsyncConnection
|
||||||
|
from langchain_postgres import (
|
||||||
|
PostgresCheckpoint, PickleCheckpointSerializer
|
||||||
|
)
|
||||||
|
|
||||||
|
conninfo="postgresql://user:password@localhost:5432/dbname"
|
||||||
|
# Take care of closing the connection when done
|
||||||
|
async with AsyncConnection(conninfo=conninfo) as conn:
|
||||||
|
# Uses the pickle module for serialization
|
||||||
|
# Make sure that you're only de-serializing trusted data
|
||||||
|
# (e.g., payloads that you have serialized yourself).
|
||||||
|
# Or implement a custom serializer.
|
||||||
|
checkpoint = PostgresCheckpoint(
|
||||||
|
serializer=PickleCheckpointSerializer(),
|
||||||
|
async_connection=conn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the checkpoint object to put, get, list checkpoints, etc.
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
serializer: CheckpointSerializer
|
||||||
|
"""The serializer for serializing and deserializing objects to and from bytes."""
|
||||||
|
|
||||||
|
sync_connection: Optional[Union[psycopg.Connection, ConnectionPool]] = None
|
||||||
|
"""The synchronous connection or pool to the Postgres database.
|
||||||
|
|
||||||
|
If providing a connection object, please ensure that the connection is open
|
||||||
|
and remember to close the connection when done.
|
||||||
|
"""
|
||||||
|
async_connection: Optional[
|
||||||
|
Union[psycopg.AsyncConnection, AsyncConnectionPool]
|
||||||
|
] = None
|
||||||
|
"""The asynchronous connection or pool to the Postgres database.
|
||||||
|
|
||||||
|
If providing a connection object, please ensure that the connection is open
|
||||||
|
and remember to close the connection when done.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||||
|
"""Return the configuration specs for this runnable."""
|
||||||
|
return [
|
||||||
|
ConfigurableFieldSpec(
|
||||||
|
id="thread_id",
|
||||||
|
annotation=Optional[str],
|
||||||
|
name="Thread ID",
|
||||||
|
description=None,
|
||||||
|
default=None,
|
||||||
|
is_shared=True,
|
||||||
|
),
|
||||||
|
CheckpointThreadTs,
|
||||||
|
]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _get_sync_connection(self) -> Generator[psycopg.Connection, None, None]:
|
||||||
|
"""Get the connection to the Postgres database."""
|
||||||
|
if isinstance(self.sync_connection, psycopg.Connection):
|
||||||
|
yield self.sync_connection
|
||||||
|
elif isinstance(self.sync_connection, ConnectionPool):
|
||||||
|
with self.sync_connection.connection() as conn:
|
||||||
|
yield conn
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid sync connection object. Please initialize the check pointer "
|
||||||
|
f"with an appropriate sync connection object. "
|
||||||
|
f"Got {type(self.sync_connection)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _get_async_connection(
|
||||||
|
self,
|
||||||
|
) -> AsyncGenerator[psycopg.AsyncConnection, None]:
|
||||||
|
"""Get the connection to the Postgres database."""
|
||||||
|
if isinstance(self.async_connection, psycopg.AsyncConnection):
|
||||||
|
yield self.async_connection
|
||||||
|
elif isinstance(self.async_connection, AsyncConnectionPool):
|
||||||
|
async with self.async_connection.connection() as conn:
|
||||||
|
yield conn
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid async connection object. Please initialize the check pointer "
|
||||||
|
f"with an appropriate async connection object. "
|
||||||
|
f"Got {type(self.async_connection)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_schema(connection: psycopg.Connection, /) -> None:
|
||||||
|
"""Create the schema for the checkpoint saver."""
|
||||||
|
with connection.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
checkpoint BYTEA NOT NULL,
|
||||||
|
thread_ts TIMESTAMPTZ NOT NULL,
|
||||||
|
parent_ts TIMESTAMPTZ,
|
||||||
|
PRIMARY KEY (thread_id, thread_ts)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def acreate_schema(connection: psycopg.AsyncConnection, /) -> None:
|
||||||
|
"""Create the schema for the checkpoint saver."""
|
||||||
|
async with connection.cursor() as cur:
|
||||||
|
await cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
checkpoint BYTEA NOT NULL,
|
||||||
|
thread_ts TIMESTAMPTZ NOT NULL,
|
||||||
|
parent_ts TIMESTAMPTZ,
|
||||||
|
PRIMARY KEY (thread_id, thread_ts)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def drop_schema(connection: psycopg.Connection, /) -> None:
|
||||||
|
"""Drop the table for the checkpoint saver."""
|
||||||
|
with connection.cursor() as cur:
|
||||||
|
cur.execute("DROP TABLE IF EXISTS checkpoints;")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def adrop_schema(connection: psycopg.AsyncConnection, /) -> None:
|
||||||
|
"""Drop the table for the checkpoint saver."""
|
||||||
|
async with connection.cursor() as cur:
|
||||||
|
await cur.execute("DROP TABLE IF EXISTS checkpoints;")
|
||||||
|
|
||||||
|
def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
|
||||||
|
"""Put the checkpoint for the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration for the checkpoint.
|
||||||
|
A dict with a `configurable` key which is a dict with
|
||||||
|
a `thread_id` key and an optional `thread_ts` key.
|
||||||
|
For example, { 'configurable': { 'thread_id': 'test_thread' } }
|
||||||
|
checkpoint: The checkpoint to persist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The RunnableConfig that describes the checkpoint that was just created.
|
||||||
|
It'll contain the `thread_id` and `thread_ts` of the checkpoint.
|
||||||
|
"""
|
||||||
|
thread_id = config["configurable"]["thread_id"]
|
||||||
|
parent_ts = config["configurable"].get("thread_ts")
|
||||||
|
|
||||||
|
with self._get_sync_connection() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO checkpoints
|
||||||
|
(thread_id, thread_ts, parent_ts, checkpoint)
|
||||||
|
VALUES
|
||||||
|
(%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s)
|
||||||
|
ON CONFLICT (thread_id, thread_ts)
|
||||||
|
DO UPDATE SET checkpoint = EXCLUDED.checkpoint;
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": checkpoint["ts"],
|
||||||
|
"parent_ts": parent_ts if parent_ts else None,
|
||||||
|
"checkpoint": self.serializer.dumps(checkpoint),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": checkpoint["ts"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def aput(
|
||||||
|
self, config: RunnableConfig, checkpoint: Checkpoint
|
||||||
|
) -> RunnableConfig:
|
||||||
|
"""Put the checkpoint for the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration for the checkpoint.
|
||||||
|
A dict with a `configurable` key which is a dict with
|
||||||
|
a `thread_id` key and an optional `thread_ts` key.
|
||||||
|
For example, { 'configurable': { 'thread_id': 'test_thread' } }
|
||||||
|
checkpoint: The checkpoint to persist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The RunnableConfig that describes the checkpoint that was just created.
|
||||||
|
It'll contain the `thread_id` and `thread_ts` of the checkpoint.
|
||||||
|
"""
|
||||||
|
thread_id = config["configurable"]["thread_id"]
|
||||||
|
parent_ts = config["configurable"].get("thread_ts")
|
||||||
|
async with self._get_async_connection() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO
|
||||||
|
checkpoints (thread_id, thread_ts, parent_ts, checkpoint)
|
||||||
|
VALUES
|
||||||
|
(%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s)
|
||||||
|
ON CONFLICT (thread_id, thread_ts)
|
||||||
|
DO UPDATE SET checkpoint = EXCLUDED.checkpoint;
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": checkpoint["ts"],
|
||||||
|
"parent_ts": parent_ts if parent_ts else None,
|
||||||
|
"checkpoint": self.serializer.dumps(checkpoint),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": checkpoint["ts"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None]:
|
||||||
|
"""Get all the checkpoints for the given configuration."""
|
||||||
|
with self._get_sync_connection() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
thread_id = config["configurable"]["thread_id"]
|
||||||
|
cur.execute(
|
||||||
|
"SELECT checkpoint, thread_ts, parent_ts "
|
||||||
|
"FROM checkpoints "
|
||||||
|
"WHERE thread_id = %(thread_id)s "
|
||||||
|
"ORDER BY thread_ts DESC",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for value in cur:
|
||||||
|
yield CheckpointTuple(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[1].isoformat(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
self.serializer.loads(value[0]),
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[2].isoformat(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value[2]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
|
||||||
|
"""Get all the checkpoints for the given configuration."""
|
||||||
|
async with self._get_async_connection() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
thread_id = config["configurable"]["thread_id"]
|
||||||
|
await cur.execute(
|
||||||
|
"SELECT checkpoint, thread_ts, parent_ts "
|
||||||
|
"FROM checkpoints "
|
||||||
|
"WHERE thread_id = %(thread_id)s "
|
||||||
|
"ORDER BY thread_ts DESC",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async for value in cur:
|
||||||
|
yield CheckpointTuple(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[1].isoformat(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
self.serializer.loads(value[0]),
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[2].isoformat(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value[2]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
|
||||||
|
"""Get the checkpoint tuple for the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration for the checkpoint.
|
||||||
|
A dict with a `configurable` key which is a dict with
|
||||||
|
a `thread_id` key and an optional `thread_ts` key.
|
||||||
|
For example, { 'configurable': { 'thread_id': 'test_thread' } }
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The checkpoint tuple for the given configuration if it exists,
|
||||||
|
otherwise None.
|
||||||
|
|
||||||
|
If thread_ts is None, the latest checkpoint is returned if it exists.
|
||||||
|
"""
|
||||||
|
thread_id = config["configurable"]["thread_id"]
|
||||||
|
thread_ts = config["configurable"].get("thread_ts")
|
||||||
|
with self._get_sync_connection() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
if thread_ts:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT checkpoint, parent_ts "
|
||||||
|
"FROM checkpoints "
|
||||||
|
"WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": thread_ts,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
value = cur.fetchone()
|
||||||
|
if value:
|
||||||
|
return CheckpointTuple(
|
||||||
|
config,
|
||||||
|
self.serializer.loads(value[0]),
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[1].isoformat(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value[1]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT checkpoint, thread_ts, parent_ts "
|
||||||
|
"FROM checkpoints "
|
||||||
|
"WHERE thread_id = %(thread_id)s "
|
||||||
|
"ORDER BY thread_ts DESC LIMIT 1",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
value = cur.fetchone()
|
||||||
|
if value:
|
||||||
|
return CheckpointTuple(
|
||||||
|
config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[1].isoformat(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
checkpoint=self.serializer.loads(value[0]),
|
||||||
|
parent_config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[2].isoformat(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value[2]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
|
||||||
|
"""Get the checkpoint tuple for the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration for the checkpoint.
|
||||||
|
A dict with a `configurable` key which is a dict with
|
||||||
|
a `thread_id` key and an optional `thread_ts` key.
|
||||||
|
For example, { 'configurable': { 'thread_id': 'test_thread' } }
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The checkpoint tuple for the given configuration if it exists,
|
||||||
|
otherwise None.
|
||||||
|
|
||||||
|
If thread_ts is None, the latest checkpoint is returned if it exists.
|
||||||
|
"""
|
||||||
|
thread_id = config["configurable"]["thread_id"]
|
||||||
|
thread_ts = config["configurable"].get("thread_ts")
|
||||||
|
async with self._get_async_connection() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
if thread_ts:
|
||||||
|
await cur.execute(
|
||||||
|
"SELECT checkpoint, parent_ts "
|
||||||
|
"FROM checkpoints "
|
||||||
|
"WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": thread_ts,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
value = await cur.fetchone()
|
||||||
|
if value:
|
||||||
|
return CheckpointTuple(
|
||||||
|
config,
|
||||||
|
self.serializer.loads(value[0]),
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[1].isoformat(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value[1]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await cur.execute(
|
||||||
|
"SELECT checkpoint, thread_ts, parent_ts "
|
||||||
|
"FROM checkpoints "
|
||||||
|
"WHERE thread_id = %(thread_id)s "
|
||||||
|
"ORDER BY thread_ts DESC LIMIT 1",
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
value = await cur.fetchone()
|
||||||
|
if value:
|
||||||
|
return CheckpointTuple(
|
||||||
|
config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[1].isoformat(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
checkpoint=self.serializer.loads(value[0]),
|
||||||
|
parent_config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"thread_ts": value[2].isoformat(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value[2]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
67
libs/partners/postgres/poetry.lock
generated
67
libs/partners/postgres/poetry.lock
generated
@ -11,37 +11,6 @@ files = [
|
|||||||
{file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"},
|
{file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "backports-zoneinfo"
|
|
||||||
version = "0.2.1"
|
|
||||||
description = "Backport of the standard library zoneinfo module"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.6"
|
|
||||||
files = [
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"},
|
|
||||||
{file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
tzdata = ["tzdata"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "certifi"
|
name = "certifi"
|
||||||
version = "2024.2.2"
|
version = "2024.2.2"
|
||||||
@ -243,7 +212,7 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.1.37"
|
version = "0.1.40"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
@ -256,7 +225,6 @@ langsmith = "^0.1.0"
|
|||||||
packaging = "^23.2"
|
packaging = "^23.2"
|
||||||
pydantic = ">=1,<3"
|
pydantic = ">=1,<3"
|
||||||
PyYAML = ">=5.3"
|
PyYAML = ">=5.3"
|
||||||
requests = "^2"
|
|
||||||
tenacity = "^8.1.0"
|
tenacity = "^8.1.0"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
@ -266,6 +234,20 @@ extended-testing = ["jinja2 (>=3,<4)"]
|
|||||||
type = "directory"
|
type = "directory"
|
||||||
url = "../../core"
|
url = "../../core"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "langgraph"
|
||||||
|
version = "0.0.32"
|
||||||
|
description = "langgraph"
|
||||||
|
optional = false
|
||||||
|
python-versions = "<4.0,>=3.9.0"
|
||||||
|
files = [
|
||||||
|
{file = "langgraph-0.0.32-py3-none-any.whl", hash = "sha256:b9330b75b420f6fc0b8b238c3dd974166e4e779fd11b6c73c58754db14644cb5"},
|
||||||
|
{file = "langgraph-0.0.32.tar.gz", hash = "sha256:28338cc525ae82b240de89bffec1bae412fedb4edb6267de5c7f944c47ea8263"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
langchain-core = ">=0.1.38,<0.2.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langsmith"
|
name = "langsmith"
|
||||||
version = "0.1.38"
|
version = "0.1.38"
|
||||||
@ -438,7 +420,6 @@ files = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
"backports.zoneinfo" = {version = ">=0.2.0", markers = "python_version < \"3.9\""}
|
|
||||||
typing-extensions = ">=4.1"
|
typing-extensions = ">=4.1"
|
||||||
tzdata = {version = "*", markers = "sys_platform == \"win32\""}
|
tzdata = {version = "*", markers = "sys_platform == \"win32\""}
|
||||||
|
|
||||||
@ -450,6 +431,20 @@ docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)"
|
|||||||
pool = ["psycopg-pool"]
|
pool = ["psycopg-pool"]
|
||||||
test = ["anyio (>=3.6.2,<4.0)", "mypy (>=1.4.1)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"]
|
test = ["anyio (>=3.6.2,<4.0)", "mypy (>=1.4.1)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "psycopg-pool"
|
||||||
|
version = "3.2.1"
|
||||||
|
description = "Connection Pool for Psycopg"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "psycopg-pool-3.2.1.tar.gz", hash = "sha256:6509a75c073590952915eddbba7ce8b8332a440a31e77bba69561483492829ad"},
|
||||||
|
{file = "psycopg_pool-3.2.1-py3-none-any.whl", hash = "sha256:060b551d1b97a8d358c668be58b637780b884de14d861f4f5ecc48b7563aafb7"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
typing-extensions = ">=4.4"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pydantic"
|
name = "pydantic"
|
||||||
version = "2.6.4"
|
version = "2.6.4"
|
||||||
@ -772,5 +767,5 @@ zstd = ["zstandard (>=0.18.0)"]
|
|||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = "^3.9"
|
||||||
content-hash = "3ffcb2a37d4a25f6073fd59a7f2a14cdd89f03d847e651e9bcd8625426d28f50"
|
content-hash = "ee9808589dabaecefbb3b06d09e0c7a172116173ca9ea0de28263396793f377a"
|
||||||
|
@ -11,9 +11,11 @@ license = "MIT"
|
|||||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/postgres"
|
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/postgres"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<4.0"
|
python = "^3.9"
|
||||||
langchain-core = "^0.1"
|
langchain-core = "^0.1"
|
||||||
psycopg = "^3.1.18"
|
psycopg = "^3.1.18"
|
||||||
|
langgraph = "^0.0.32"
|
||||||
|
psycopg-pool = "^3.2.1"
|
||||||
|
|
||||||
[tool.poetry.group.test]
|
[tool.poetry.group.test]
|
||||||
optional = true
|
optional = true
|
||||||
|
@ -0,0 +1,326 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from langgraph.checkpoint import Checkpoint
|
||||||
|
from langgraph.checkpoint.base import CheckpointTuple
|
||||||
|
|
||||||
|
from langchain_postgres.checkpoint import PickleCheckpointSerializer, PostgresCheckpoint
|
||||||
|
from tests.utils import asyncpg_client, syncpg_client
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_checkpoint() -> None:
|
||||||
|
"""Test the async chat history."""
|
||||||
|
async with asyncpg_client() as async_connection:
|
||||||
|
await PostgresCheckpoint.adrop_schema(async_connection)
|
||||||
|
await PostgresCheckpoint.acreate_schema(async_connection)
|
||||||
|
checkpoint_saver = PostgresCheckpoint(
|
||||||
|
async_connection=async_connection, serializer=PickleCheckpointSerializer()
|
||||||
|
)
|
||||||
|
checkpoint_tuple = [
|
||||||
|
c
|
||||||
|
async for c in checkpoint_saver.alist(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
assert len(checkpoint_tuple) == 0
|
||||||
|
|
||||||
|
# Add a checkpoint
|
||||||
|
sample_checkpoint: Checkpoint = {
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(),
|
||||||
|
"versions_seen": defaultdict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
await checkpoint_saver.aput(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
sample_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = [
|
||||||
|
c
|
||||||
|
async for c in checkpoint_saver.alist(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
assert checkpoints[0].checkpoint == sample_checkpoint
|
||||||
|
|
||||||
|
# Add another checkpoint
|
||||||
|
sample_checkpoint2: Checkpoint = {
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-02T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(),
|
||||||
|
"versions_seen": defaultdict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
await checkpoint_saver.aput(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
sample_checkpoint2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try aget
|
||||||
|
checkpoints = [
|
||||||
|
c
|
||||||
|
async for c in checkpoint_saver.alist(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(checkpoints) == 2
|
||||||
|
# Should be sorted by timestamp desc
|
||||||
|
assert checkpoints[0].checkpoint == sample_checkpoint2
|
||||||
|
assert checkpoints[1].checkpoint == sample_checkpoint
|
||||||
|
|
||||||
|
assert await checkpoint_saver.aget_tuple(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
) == CheckpointTuple(
|
||||||
|
config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-02T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
checkpoint={
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-02T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": {}, # type: ignore
|
||||||
|
"versions_seen": {}, # type: ignore
|
||||||
|
},
|
||||||
|
parent_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check aget_tuple with thread_ts
|
||||||
|
assert await checkpoint_saver.aget_tuple(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
) == CheckpointTuple(
|
||||||
|
config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
checkpoint={
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": {}, # type: ignore
|
||||||
|
"versions_seen": {}, # type: ignore
|
||||||
|
},
|
||||||
|
parent_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_checkpoint() -> None:
|
||||||
|
"""Test the sync check point implementation."""
|
||||||
|
with syncpg_client() as sync_connection:
|
||||||
|
PostgresCheckpoint.drop_schema(sync_connection)
|
||||||
|
PostgresCheckpoint.create_schema(sync_connection)
|
||||||
|
checkpoint_saver = PostgresCheckpoint(
|
||||||
|
sync_connection=sync_connection, serializer=PickleCheckpointSerializer()
|
||||||
|
)
|
||||||
|
checkpoint_tuple = [
|
||||||
|
c
|
||||||
|
for c in checkpoint_saver.list(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
assert len(checkpoint_tuple) == 0
|
||||||
|
|
||||||
|
# Add a checkpoint
|
||||||
|
sample_checkpoint: Checkpoint = {
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(),
|
||||||
|
"versions_seen": defaultdict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_saver.put(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
sample_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = [
|
||||||
|
c
|
||||||
|
for c in checkpoint_saver.list(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
assert checkpoints[0].checkpoint == sample_checkpoint
|
||||||
|
|
||||||
|
# Add another checkpoint
|
||||||
|
sample_checkpoint_2: Checkpoint = {
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-02T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(),
|
||||||
|
"versions_seen": defaultdict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_saver.put(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
sample_checkpoint_2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try aget
|
||||||
|
checkpoints = [
|
||||||
|
c
|
||||||
|
for c in checkpoint_saver.list(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(checkpoints) == 2
|
||||||
|
# Should be sorted by timestamp desc
|
||||||
|
assert checkpoints[0].checkpoint == sample_checkpoint_2
|
||||||
|
assert checkpoints[1].checkpoint == sample_checkpoint
|
||||||
|
|
||||||
|
assert checkpoint_saver.get_tuple(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
) == CheckpointTuple(
|
||||||
|
config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-02T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
checkpoint={
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-02T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(),
|
||||||
|
"versions_seen": defaultdict(),
|
||||||
|
},
|
||||||
|
parent_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_on_conflict_aput() -> None:
|
||||||
|
async with asyncpg_client() as async_connection:
|
||||||
|
await PostgresCheckpoint.adrop_schema(async_connection)
|
||||||
|
await PostgresCheckpoint.acreate_schema(async_connection)
|
||||||
|
checkpoint_saver = PostgresCheckpoint(
|
||||||
|
async_connection=async_connection, serializer=PickleCheckpointSerializer()
|
||||||
|
)
|
||||||
|
|
||||||
|
# aput with twice on the same (thread_id, thread_ts) should not raise any error
|
||||||
|
sample_checkpoint: Checkpoint = {
|
||||||
|
"v": 1,
|
||||||
|
"ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(),
|
||||||
|
"versions_seen": defaultdict(),
|
||||||
|
}
|
||||||
|
new_checkpoint: Checkpoint = {
|
||||||
|
"v": 2,
|
||||||
|
"ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(),
|
||||||
|
"versions_seen": defaultdict(),
|
||||||
|
}
|
||||||
|
await checkpoint_saver.aput(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
sample_checkpoint,
|
||||||
|
)
|
||||||
|
await checkpoint_saver.aput(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new_checkpoint,
|
||||||
|
)
|
||||||
|
# Check aget_tuple with thread_ts
|
||||||
|
assert await checkpoint_saver.aget_tuple(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
) == CheckpointTuple(
|
||||||
|
config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
checkpoint={
|
||||||
|
"v": 2,
|
||||||
|
"ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
"channel_values": {},
|
||||||
|
"channel_versions": defaultdict(None, {}),
|
||||||
|
"versions_seen": defaultdict(None, {}),
|
||||||
|
},
|
||||||
|
parent_config={
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": "test_thread",
|
||||||
|
"thread_ts": "2021-09-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
@ -1,7 +1,14 @@
|
|||||||
from langchain_postgres import __all__
|
from langchain_postgres import __all__
|
||||||
|
|
||||||
EXPECTED_ALL = ["__version__", "PostgresChatMessageHistory"]
|
EXPECTED_ALL = [
|
||||||
|
"__version__",
|
||||||
|
"CheckpointSerializer",
|
||||||
|
"PostgresChatMessageHistory",
|
||||||
|
"PostgresCheckpoint",
|
||||||
|
"PickleCheckpointSerializer",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_all_imports() -> None:
|
def test_all_imports() -> None:
|
||||||
|
"""Test that __all__ is correctly defined."""
|
||||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
||||||
|
Loading…
Reference in New Issue
Block a user