From ba9e0d76c1691608c3ecc8510db1f53c7cb36dcd Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 8 Apr 2024 09:27:15 -0400 Subject: [PATCH] postgres[minor]: add postgres checkpoint implementation (#20025) Adds checkpoint implementation using psycopg --- libs/partners/postgres/README.md | 58 ++ .../postgres/langchain_postgres/__init__.py | 8 + .../postgres/langchain_postgres/checkpoint.py | 565 ++++++++++++++++++ libs/partners/postgres/poetry.lock | 67 +-- libs/partners/postgres/pyproject.toml | 4 +- .../integration_tests/test_checkpointer.py | 326 ++++++++++ .../postgres/tests/unit_tests/test_imports.py | 9 +- 7 files changed, 999 insertions(+), 38 deletions(-) create mode 100644 libs/partners/postgres/langchain_postgres/checkpoint.py create mode 100644 libs/partners/postgres/tests/integration_tests/test_checkpointer.py diff --git a/libs/partners/postgres/README.md b/libs/partners/postgres/README.md index 2552c1ff24a..55084835b4e 100644 --- a/libs/partners/postgres/README.md +++ b/libs/partners/postgres/README.md @@ -63,3 +63,61 @@ chat_history.add_messages([ print(chat_history.messages) ``` + + +### PostgresCheckpoint + +An implementation of the `Checkpoint` abstraction in LangGraph using Postgres. + + +Async Usage: + +```python +from psycopg_pool import AsyncConnectionPool +from langchain_postgres import ( + PostgresCheckpoint, PickleCheckpointSerializer +) + +pool = AsyncConnectionPool( + # Example configuration + conninfo="postgresql://user:password@localhost:5432/dbname", + max_size=20, +) + +# Uses the pickle module for serialization +# Make sure that you're only de-serializing trusted data +# (e.g., payloads that you have serialized yourself). +# Or implement a custom serializer. +checkpoint = PostgresCheckpoint( + serializer=PickleCheckpointSerializer(), + async_connection=pool, +) + +# Use the checkpoint object to put, get, list checkpoints, etc. +``` + +Sync Usage: + +```python +from psycopg_pool import ConnectionPool +from langchain_postgres import ( + PostgresCheckpoint, PickleCheckpointSerializer +) + +pool = ConnectionPool( + # Example configuration + conninfo="postgresql://user:password@localhost:5432/dbname", + max_size=20, +) + +# Uses the pickle module for serialization +# Make sure that you're only de-serializing trusted data +# (e.g., payloads that you have serialized yourself). +# Or implement a custom serializer. +checkpoint = PostgresCheckpoint( + serializer=PickleCheckpointSerializer(), + sync_connection=pool, +) + +# Use the checkpoint object to put, get, list checkpoints, etc. 
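+
+# A usage sketch (the thread id, timestamp, and checkpoint contents below
+# are made-up illustrations; the dict shapes follow langgraph's
+# `RunnableConfig` and `Checkpoint` contracts).
+
+# First-time setup: create the `checkpoints` table on a connection
+# borrowed from the pool.
+with pool.connection() as conn:
+    PostgresCheckpoint.create_schema(conn)
+
+config = {"configurable": {"thread_id": "example_thread"}}
+
+# Persist a checkpoint; the returned config carries its thread_id/thread_ts.
+saved_config = checkpoint.put(
+    config,
+    {
+        "v": 1,
+        "ts": "2024-01-01T00:00:00+00:00",
+        "channel_values": {},
+        "channel_versions": {},
+        "versions_seen": {},
+    },
+)
+
+# Latest checkpoint tuple for the thread (None if the thread has none yet).
+latest = checkpoint.get_tuple(config)
+
+# All checkpoints for the thread, newest first.
+for checkpoint_tuple in checkpoint.list(config):
+    print(checkpoint_tuple.checkpoint["ts"])
+```
+
+If pickle does not suit your deployment (e.g., payloads must survive Python
+upgrades), implement `CheckpointSerializer` yourself. Below is a minimal
+sketch, assuming the checkpoint contents are JSON-serializable (true for the
+dict-based `Checkpoint` type only when its channel values are JSON-friendly);
+the class name is illustrative, not part of the package:
+
+```python
+import json
+from typing import cast
+
+from langgraph.checkpoint.base import Checkpoint
+
+from langchain_postgres import CheckpointSerializer
+
+
+class JsonCheckpointSerializer(CheckpointSerializer):
+    """Serialize checkpoints as UTF-8 encoded JSON."""
+
+    def dumps(self, obj: Checkpoint) -> bytes:
+        return json.dumps(obj).encode("utf-8")
+
+    def loads(self, data: bytes) -> Checkpoint:
+        return cast(Checkpoint, json.loads(data.decode("utf-8")))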
+```
diff --git a/libs/partners/postgres/langchain_postgres/__init__.py b/libs/partners/postgres/langchain_postgres/__init__.py
index ba755ca05c6..ddda22e4803 100644
--- a/libs/partners/postgres/langchain_postgres/__init__.py
+++ b/libs/partners/postgres/langchain_postgres/__init__.py
@@ -1,6 +1,11 @@
 from importlib import metadata
 
 from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
+from langchain_postgres.checkpoint import (
+    CheckpointSerializer,
+    PickleCheckpointSerializer,
+    PostgresCheckpoint,
+)
 
 try:
     __version__ = metadata.version(__package__)
@@ -10,5 +15,8 @@ except metadata.PackageNotFoundError:
 
 __all__ = [
     "__version__",
+    "CheckpointSerializer",
     "PostgresChatMessageHistory",
+    "PostgresCheckpoint",
+    "PickleCheckpointSerializer",
 ]
diff --git a/libs/partners/postgres/langchain_postgres/checkpoint.py b/libs/partners/postgres/langchain_postgres/checkpoint.py
new file mode 100644
index 00000000000..89a6972991a
--- /dev/null
+++ b/libs/partners/postgres/langchain_postgres/checkpoint.py
@@ -0,0 +1,565 @@
+"""Implementation of a langgraph checkpoint saver using Postgres."""
+import abc
+import pickle
+from contextlib import asynccontextmanager, contextmanager
+from typing import AsyncGenerator, AsyncIterator, Generator, Optional, Union, cast
+
+import psycopg
+from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
+from langgraph.checkpoint import BaseCheckpointSaver
+from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple
+from psycopg_pool import AsyncConnectionPool, ConnectionPool
+
+
+class CheckpointSerializer(abc.ABC):
+    """A serializer for serializing and deserializing objects to and from bytes."""
+
+    @abc.abstractmethod
+    def dumps(self, obj: Checkpoint) -> bytes:
+        """Serialize an object to bytes."""
+
+    @abc.abstractmethod
+    def loads(self, data: bytes) -> Checkpoint:
+        """Deserialize an object from bytes."""
+
+
+class PickleCheckpointSerializer(CheckpointSerializer):
+    """Use the pickle module to serialize and deserialize objects.
+
+    While pickling can serialize a wide range of Python objects, it may fail
+    to deserialize objects after updates to the Python version or the Python
+    environment (e.g., if the object's class definition changes in LangGraph).
+
+    *Security Warning*: The pickle module can execute arbitrary code during
+    deserialization. Only use this serializer with trusted data, e.g., data
+    that you have serialized yourself and can guarantee the integrity of.
+    """
+
+    def dumps(self, obj: Checkpoint) -> bytes:
+        """Serialize an object to bytes."""
+        return pickle.dumps(obj)
+
+    def loads(self, data: bytes) -> Checkpoint:
+        """Deserialize an object from bytes."""
+        return cast(Checkpoint, pickle.loads(data))
+
+
+class PostgresCheckpoint(BaseCheckpointSaver):
+    """LangGraph checkpoint saver for Postgres.
+
+    This implementation of a checkpoint saver uses a Postgres database to save
+    and retrieve checkpoints. It uses the psycopg3 package to interact with the
+    Postgres database.
+
+    The checkpoint saver accepts either a `sync_connection` in the form of a
+    `psycopg.Connection` or a `psycopg_pool.ConnectionPool` object, or an
+    `async_connection` in the form of a `psycopg.AsyncConnection` or a
+    `psycopg_pool.AsyncConnectionPool` object.
+
+    Usage:
+
+    1. First-time use: create the schema in the database using the `create_schema`
+       method or the async version `acreate_schema`.
+    2. 
Create a PostgresCheckpoint object with a serializer and an appropriate + connection object. + It's recommended to use a connection pool object for the connection. + If using a connection object, you are responsible for closing the connection + when done. + + Examples: + + + Sync usage with a connection pool: + + .. code-block:: python + + from psycopg_pool import ConnectionPool + from langchain_postgres import ( + PostgresCheckpoint, PickleCheckpointSerializer + ) + + pool = ConnectionPool( + # Example configuration + conninfo="postgresql://user:password@localhost:5432/dbname", + max_size=20, + ) + + # Uses the pickle module for serialization + # Make sure that you're only de-serializing trusted data + # (e.g., payloads that you have serialized yourself). + # Or implement a custom serializer. + checkpoint = PostgresCheckpoint( + serializer=PickleCheckpointSerializer(), + sync_connection=pool, + ) + + # Use the checkpoint object to put, get, list checkpoints, etc. + + + Async usage with a connection pool: + + .. code-block:: python + + from psycopg_pool import AsyncConnectionPool + from langchain_postgres import ( + PostgresCheckpoint, PickleCheckpointSerializer + ) + + pool = AsyncConnectionPool( + # Example configuration + conninfo="postgresql://user:password@localhost:5432/dbname", + max_size=20, + ) + + # Uses the pickle module for serialization + # Make sure that you're only de-serializing trusted data + # (e.g., payloads that you have serialized yourself). + # Or implement a custom serializer. + checkpoint = PostgresCheckpoint( + serializer=PickleCheckpointSerializer(), + async_connection=pool, + ) + + # Use the checkpoint object to put, get, list checkpoints, etc. + + + Async usage with a connection object: + + .. code-block:: python + + from psycopg import AsyncConnection + from langchain_postgres import ( + PostgresCheckpoint, PickleCheckpointSerializer + ) + + conninfo="postgresql://user:password@localhost:5432/dbname" + # Take care of closing the connection when done + async with AsyncConnection(conninfo=conninfo) as conn: + # Uses the pickle module for serialization + # Make sure that you're only de-serializing trusted data + # (e.g., payloads that you have serialized yourself). + # Or implement a custom serializer. + checkpoint = PostgresCheckpoint( + serializer=PickleCheckpointSerializer(), + async_connection=conn, + ) + + # Use the checkpoint object to put, get, list checkpoints, etc. + ... + """ + + serializer: CheckpointSerializer + """The serializer for serializing and deserializing objects to and from bytes.""" + + sync_connection: Optional[Union[psycopg.Connection, ConnectionPool]] = None + """The synchronous connection or pool to the Postgres database. + + If providing a connection object, please ensure that the connection is open + and remember to close the connection when done. + """ + async_connection: Optional[ + Union[psycopg.AsyncConnection, AsyncConnectionPool] + ] = None + """The asynchronous connection or pool to the Postgres database. + + If providing a connection object, please ensure that the connection is open + and remember to close the connection when done. 
+ """ + + class Config: + arbitrary_types_allowed = True + extra = "forbid" + + @property + def config_specs(self) -> list[ConfigurableFieldSpec]: + """Return the configuration specs for this runnable.""" + return [ + ConfigurableFieldSpec( + id="thread_id", + annotation=Optional[str], + name="Thread ID", + description=None, + default=None, + is_shared=True, + ), + CheckpointThreadTs, + ] + + @contextmanager + def _get_sync_connection(self) -> Generator[psycopg.Connection, None, None]: + """Get the connection to the Postgres database.""" + if isinstance(self.sync_connection, psycopg.Connection): + yield self.sync_connection + elif isinstance(self.sync_connection, ConnectionPool): + with self.sync_connection.connection() as conn: + yield conn + else: + raise ValueError( + "Invalid sync connection object. Please initialize the check pointer " + f"with an appropriate sync connection object. " + f"Got {type(self.sync_connection)}." + ) + + @asynccontextmanager + async def _get_async_connection( + self, + ) -> AsyncGenerator[psycopg.AsyncConnection, None]: + """Get the connection to the Postgres database.""" + if isinstance(self.async_connection, psycopg.AsyncConnection): + yield self.async_connection + elif isinstance(self.async_connection, AsyncConnectionPool): + async with self.async_connection.connection() as conn: + yield conn + else: + raise ValueError( + "Invalid async connection object. Please initialize the check pointer " + f"with an appropriate async connection object. " + f"Got {type(self.async_connection)}." + ) + + @staticmethod + def create_schema(connection: psycopg.Connection, /) -> None: + """Create the schema for the checkpoint saver.""" + with connection.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + checkpoint BYTEA NOT NULL, + thread_ts TIMESTAMPTZ NOT NULL, + parent_ts TIMESTAMPTZ, + PRIMARY KEY (thread_id, thread_ts) + ); + """ + ) + + @staticmethod + async def acreate_schema(connection: psycopg.AsyncConnection, /) -> None: + """Create the schema for the checkpoint saver.""" + async with connection.cursor() as cur: + await cur.execute( + """ + CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + checkpoint BYTEA NOT NULL, + thread_ts TIMESTAMPTZ NOT NULL, + parent_ts TIMESTAMPTZ, + PRIMARY KEY (thread_id, thread_ts) + ); + """ + ) + + @staticmethod + def drop_schema(connection: psycopg.Connection, /) -> None: + """Drop the table for the checkpoint saver.""" + with connection.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS checkpoints;") + + @staticmethod + async def adrop_schema(connection: psycopg.AsyncConnection, /) -> None: + """Drop the table for the checkpoint saver.""" + async with connection.cursor() as cur: + await cur.execute("DROP TABLE IF EXISTS checkpoints;") + + def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig: + """Put the checkpoint for the given configuration. + + Args: + config: The configuration for the checkpoint. + A dict with a `configurable` key which is a dict with + a `thread_id` key and an optional `thread_ts` key. + For example, { 'configurable': { 'thread_id': 'test_thread' } } + checkpoint: The checkpoint to persist. + + Returns: + The RunnableConfig that describes the checkpoint that was just created. + It'll contain the `thread_id` and `thread_ts` of the checkpoint. 
+ """ + thread_id = config["configurable"]["thread_id"] + parent_ts = config["configurable"].get("thread_ts") + + with self._get_sync_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO checkpoints + (thread_id, thread_ts, parent_ts, checkpoint) + VALUES + (%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s) + ON CONFLICT (thread_id, thread_ts) + DO UPDATE SET checkpoint = EXCLUDED.checkpoint; + """, + { + "thread_id": thread_id, + "thread_ts": checkpoint["ts"], + "parent_ts": parent_ts if parent_ts else None, + "checkpoint": self.serializer.dumps(checkpoint), + }, + ) + + return { + "configurable": { + "thread_id": thread_id, + "thread_ts": checkpoint["ts"], + }, + } + + async def aput( + self, config: RunnableConfig, checkpoint: Checkpoint + ) -> RunnableConfig: + """Put the checkpoint for the given configuration. + + Args: + config: The configuration for the checkpoint. + A dict with a `configurable` key which is a dict with + a `thread_id` key and an optional `thread_ts` key. + For example, { 'configurable': { 'thread_id': 'test_thread' } } + checkpoint: The checkpoint to persist. + + Returns: + The RunnableConfig that describes the checkpoint that was just created. + It'll contain the `thread_id` and `thread_ts` of the checkpoint. + """ + thread_id = config["configurable"]["thread_id"] + parent_ts = config["configurable"].get("thread_ts") + async with self._get_async_connection() as conn: + async with conn.cursor() as cur: + await cur.execute( + """ + INSERT INTO + checkpoints (thread_id, thread_ts, parent_ts, checkpoint) + VALUES + (%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s) + ON CONFLICT (thread_id, thread_ts) + DO UPDATE SET checkpoint = EXCLUDED.checkpoint; + """, + { + "thread_id": thread_id, + "thread_ts": checkpoint["ts"], + "parent_ts": parent_ts if parent_ts else None, + "checkpoint": self.serializer.dumps(checkpoint), + }, + ) + + return { + "configurable": { + "thread_id": thread_id, + "thread_ts": checkpoint["ts"], + }, + } + + def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None]: + """Get all the checkpoints for the given configuration.""" + with self._get_sync_connection() as conn: + with conn.cursor() as cur: + thread_id = config["configurable"]["thread_id"] + cur.execute( + "SELECT checkpoint, thread_ts, parent_ts " + "FROM checkpoints " + "WHERE thread_id = %(thread_id)s " + "ORDER BY thread_ts DESC", + { + "thread_id": thread_id, + }, + ) + for value in cur: + yield CheckpointTuple( + { + "configurable": { + "thread_id": thread_id, + "thread_ts": value[1].isoformat(), + } + }, + self.serializer.loads(value[0]), + { + "configurable": { + "thread_id": thread_id, + "thread_ts": value[2].isoformat(), + } + } + if value[2] + else None, + ) + + async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]: + """Get all the checkpoints for the given configuration.""" + async with self._get_async_connection() as conn: + async with conn.cursor() as cur: + thread_id = config["configurable"]["thread_id"] + await cur.execute( + "SELECT checkpoint, thread_ts, parent_ts " + "FROM checkpoints " + "WHERE thread_id = %(thread_id)s " + "ORDER BY thread_ts DESC", + { + "thread_id": thread_id, + }, + ) + async for value in cur: + yield CheckpointTuple( + { + "configurable": { + "thread_id": thread_id, + "thread_ts": value[1].isoformat(), + } + }, + self.serializer.loads(value[0]), + { + "configurable": { + "thread_id": thread_id, + "thread_ts": value[2].isoformat(), + } + } + if 
value[2] + else None, + ) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get the checkpoint tuple for the given configuration. + + Args: + config: The configuration for the checkpoint. + A dict with a `configurable` key which is a dict with + a `thread_id` key and an optional `thread_ts` key. + For example, { 'configurable': { 'thread_id': 'test_thread' } } + + Returns: + The checkpoint tuple for the given configuration if it exists, + otherwise None. + + If thread_ts is None, the latest checkpoint is returned if it exists. + """ + thread_id = config["configurable"]["thread_id"] + thread_ts = config["configurable"].get("thread_ts") + with self._get_sync_connection() as conn: + with conn.cursor() as cur: + if thread_ts: + cur.execute( + "SELECT checkpoint, parent_ts " + "FROM checkpoints " + "WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s", + { + "thread_id": thread_id, + "thread_ts": thread_ts, + }, + ) + value = cur.fetchone() + if value: + return CheckpointTuple( + config, + self.serializer.loads(value[0]), + { + "configurable": { + "thread_id": thread_id, + "thread_ts": value[1].isoformat(), + } + } + if value[1] + else None, + ) + else: + cur.execute( + "SELECT checkpoint, thread_ts, parent_ts " + "FROM checkpoints " + "WHERE thread_id = %(thread_id)s " + "ORDER BY thread_ts DESC LIMIT 1", + { + "thread_id": thread_id, + }, + ) + value = cur.fetchone() + if value: + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "thread_ts": value[1].isoformat(), + } + }, + checkpoint=self.serializer.loads(value[0]), + parent_config={ + "configurable": { + "thread_id": thread_id, + "thread_ts": value[2].isoformat(), + } + } + if value[2] + else None, + ) + return None + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get the checkpoint tuple for the given configuration. + + Args: + config: The configuration for the checkpoint. + A dict with a `configurable` key which is a dict with + a `thread_id` key and an optional `thread_ts` key. + For example, { 'configurable': { 'thread_id': 'test_thread' } } + + Returns: + The checkpoint tuple for the given configuration if it exists, + otherwise None. + + If thread_ts is None, the latest checkpoint is returned if it exists. 
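+
+        Example (a usage sketch; assumes ``saver`` is a
+        ``PostgresCheckpoint`` configured with an async connection):
+
+            .. code-block:: python
+
+                # Latest checkpoint for the thread (no thread_ts given):
+                latest = await saver.aget_tuple(
+                    {"configurable": {"thread_id": "test_thread"}}
+                )
+                # A specific checkpoint, addressed by its thread_ts:
+                pinned = await saver.aget_tuple(
+                    {
+                        "configurable": {
+                            "thread_id": "test_thread",
+                            "thread_ts": "2021-09-01T00:00:00+00:00",
+                        }
+                    }
+                )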
+ """ + thread_id = config["configurable"]["thread_id"] + thread_ts = config["configurable"].get("thread_ts") + async with self._get_async_connection() as conn: + async with conn.cursor() as cur: + if thread_ts: + await cur.execute( + "SELECT checkpoint, parent_ts " + "FROM checkpoints " + "WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s", + { + "thread_id": thread_id, + "thread_ts": thread_ts, + }, + ) + value = await cur.fetchone() + if value: + return CheckpointTuple( + config, + self.serializer.loads(value[0]), + { + "configurable": { + "thread_id": thread_id, + "thread_ts": value[1].isoformat(), + } + } + if value[1] + else None, + ) + else: + await cur.execute( + "SELECT checkpoint, thread_ts, parent_ts " + "FROM checkpoints " + "WHERE thread_id = %(thread_id)s " + "ORDER BY thread_ts DESC LIMIT 1", + { + "thread_id": thread_id, + }, + ) + value = await cur.fetchone() + if value: + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "thread_ts": value[1].isoformat(), + } + }, + checkpoint=self.serializer.loads(value[0]), + parent_config={ + "configurable": { + "thread_id": thread_id, + "thread_ts": value[2].isoformat(), + } + } + if value[2] + else None, + ) + + return None diff --git a/libs/partners/postgres/poetry.lock b/libs/partners/postgres/poetry.lock index 51017dac2b7..6c4ff0070d9 100644 --- a/libs/partners/postgres/poetry.lock +++ b/libs/partners/postgres/poetry.lock @@ -11,37 +11,6 @@ files = [ {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - -[[package]] -name = "backports-zoneinfo" -version = "0.2.1" -description = "Backport of the standard library zoneinfo module" -optional = false -python-versions = ">=3.6" -files = [ - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, - 
{file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, - {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, -] - -[package.extras] -tzdata = ["tzdata"] - [[package]] name = "certifi" version = "2024.2.2" @@ -243,7 +212,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.37" +version = "0.1.40" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -256,7 +225,6 @@ langsmith = "^0.1.0" packaging = "^23.2" pydantic = ">=1,<3" PyYAML = ">=5.3" -requests = "^2" tenacity = "^8.1.0" [package.extras] @@ -266,6 +234,20 @@ extended-testing = ["jinja2 (>=3,<4)"] type = "directory" url = "../../core" +[[package]] +name = "langgraph" +version = "0.0.32" +description = "langgraph" +optional = false +python-versions = "<4.0,>=3.9.0" +files = [ + {file = "langgraph-0.0.32-py3-none-any.whl", hash = "sha256:b9330b75b420f6fc0b8b238c3dd974166e4e779fd11b6c73c58754db14644cb5"}, + {file = "langgraph-0.0.32.tar.gz", hash = "sha256:28338cc525ae82b240de89bffec1bae412fedb4edb6267de5c7f944c47ea8263"}, +] + +[package.dependencies] +langchain-core = ">=0.1.38,<0.2.0" + [[package]] name = "langsmith" version = "0.1.38" @@ -438,7 +420,6 @@ files = [ ] [package.dependencies] -"backports.zoneinfo" = {version = ">=0.2.0", markers = "python_version < \"3.9\""} typing-extensions = ">=4.1" tzdata = {version = "*", markers = "sys_platform == \"win32\""} @@ -450,6 +431,20 @@ docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)" pool = ["psycopg-pool"] test = ["anyio (>=3.6.2,<4.0)", "mypy (>=1.4.1)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] +[[package]] +name = "psycopg-pool" +version = "3.2.1" +description = "Connection Pool for Psycopg" +optional = false +python-versions = ">=3.8" +files = [ + {file = "psycopg-pool-3.2.1.tar.gz", hash = "sha256:6509a75c073590952915eddbba7ce8b8332a440a31e77bba69561483492829ad"}, + {file = "psycopg_pool-3.2.1-py3-none-any.whl", hash = "sha256:060b551d1b97a8d358c668be58b637780b884de14d861f4f5ecc48b7563aafb7"}, +] + +[package.dependencies] +typing-extensions = ">=4.4" + [[package]] name = "pydantic" version = "2.6.4" @@ -772,5 +767,5 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" -python-versions = ">=3.8.1,<4.0" -content-hash = "3ffcb2a37d4a25f6073fd59a7f2a14cdd89f03d847e651e9bcd8625426d28f50" +python-versions = "^3.9" +content-hash = "ee9808589dabaecefbb3b06d09e0c7a172116173ca9ea0de28263396793f377a" diff --git a/libs/partners/postgres/pyproject.toml b/libs/partners/postgres/pyproject.toml index 008335add63..c9a598d5ca6 100644 --- a/libs/partners/postgres/pyproject.toml +++ b/libs/partners/postgres/pyproject.toml @@ -11,9 +11,11 @@ license = "MIT" "Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/postgres" [tool.poetry.dependencies] -python = ">=3.8.1,<4.0" +python = "^3.9" langchain-core = 
"^0.1" psycopg = "^3.1.18" +langgraph = "^0.0.32" +psycopg-pool = "^3.2.1" [tool.poetry.group.test] optional = true diff --git a/libs/partners/postgres/tests/integration_tests/test_checkpointer.py b/libs/partners/postgres/tests/integration_tests/test_checkpointer.py new file mode 100644 index 00000000000..1179d8b8f7d --- /dev/null +++ b/libs/partners/postgres/tests/integration_tests/test_checkpointer.py @@ -0,0 +1,326 @@ +from collections import defaultdict + +from langgraph.checkpoint import Checkpoint +from langgraph.checkpoint.base import CheckpointTuple + +from langchain_postgres.checkpoint import PickleCheckpointSerializer, PostgresCheckpoint +from tests.utils import asyncpg_client, syncpg_client + + +async def test_async_checkpoint() -> None: + """Test the async chat history.""" + async with asyncpg_client() as async_connection: + await PostgresCheckpoint.adrop_schema(async_connection) + await PostgresCheckpoint.acreate_schema(async_connection) + checkpoint_saver = PostgresCheckpoint( + async_connection=async_connection, serializer=PickleCheckpointSerializer() + ) + checkpoint_tuple = [ + c + async for c in checkpoint_saver.alist( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) + ] + assert len(checkpoint_tuple) == 0 + + # Add a checkpoint + sample_checkpoint: Checkpoint = { + "v": 1, + "ts": "2021-09-01T00:00:00+00:00", + "channel_values": {}, + "channel_versions": defaultdict(), + "versions_seen": defaultdict(), + } + + await checkpoint_saver.aput( + { + "configurable": { + "thread_id": "test_thread", + } + }, + sample_checkpoint, + ) + + checkpoints = [ + c + async for c in checkpoint_saver.alist( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) + ] + + assert len(checkpoints) == 1 + assert checkpoints[0].checkpoint == sample_checkpoint + + # Add another checkpoint + sample_checkpoint2: Checkpoint = { + "v": 1, + "ts": "2021-09-02T00:00:00+00:00", + "channel_values": {}, + "channel_versions": defaultdict(), + "versions_seen": defaultdict(), + } + + await checkpoint_saver.aput( + { + "configurable": { + "thread_id": "test_thread", + } + }, + sample_checkpoint2, + ) + + # Try aget + checkpoints = [ + c + async for c in checkpoint_saver.alist( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) + ] + + assert len(checkpoints) == 2 + # Should be sorted by timestamp desc + assert checkpoints[0].checkpoint == sample_checkpoint2 + assert checkpoints[1].checkpoint == sample_checkpoint + + assert await checkpoint_saver.aget_tuple( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) == CheckpointTuple( + config={ + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-02T00:00:00+00:00", + } + }, + checkpoint={ + "v": 1, + "ts": "2021-09-02T00:00:00+00:00", + "channel_values": {}, + "channel_versions": {}, # type: ignore + "versions_seen": {}, # type: ignore + }, + parent_config=None, + ) + + # Check aget_tuple with thread_ts + assert await checkpoint_saver.aget_tuple( + { + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-01T00:00:00+00:00", + } + } + ) == CheckpointTuple( + config={ + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-01T00:00:00+00:00", + } + }, + checkpoint={ + "v": 1, + "ts": "2021-09-01T00:00:00+00:00", + "channel_values": {}, + "channel_versions": {}, # type: ignore + "versions_seen": {}, # type: ignore + }, + parent_config=None, + ) + + +def test_sync_checkpoint() -> None: + """Test the sync check point implementation.""" + with 
syncpg_client() as sync_connection: + PostgresCheckpoint.drop_schema(sync_connection) + PostgresCheckpoint.create_schema(sync_connection) + checkpoint_saver = PostgresCheckpoint( + sync_connection=sync_connection, serializer=PickleCheckpointSerializer() + ) + checkpoint_tuple = [ + c + for c in checkpoint_saver.list( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) + ] + assert len(checkpoint_tuple) == 0 + + # Add a checkpoint + sample_checkpoint: Checkpoint = { + "v": 1, + "ts": "2021-09-01T00:00:00+00:00", + "channel_values": {}, + "channel_versions": defaultdict(), + "versions_seen": defaultdict(), + } + + checkpoint_saver.put( + { + "configurable": { + "thread_id": "test_thread", + } + }, + sample_checkpoint, + ) + + checkpoints = [ + c + for c in checkpoint_saver.list( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) + ] + + assert len(checkpoints) == 1 + assert checkpoints[0].checkpoint == sample_checkpoint + + # Add another checkpoint + sample_checkpoint_2: Checkpoint = { + "v": 1, + "ts": "2021-09-02T00:00:00+00:00", + "channel_values": {}, + "channel_versions": defaultdict(), + "versions_seen": defaultdict(), + } + + checkpoint_saver.put( + { + "configurable": { + "thread_id": "test_thread", + } + }, + sample_checkpoint_2, + ) + + # Try aget + checkpoints = [ + c + for c in checkpoint_saver.list( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) + ] + + assert len(checkpoints) == 2 + # Should be sorted by timestamp desc + assert checkpoints[0].checkpoint == sample_checkpoint_2 + assert checkpoints[1].checkpoint == sample_checkpoint + + assert checkpoint_saver.get_tuple( + { + "configurable": { + "thread_id": "test_thread", + } + } + ) == CheckpointTuple( + config={ + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-02T00:00:00+00:00", + } + }, + checkpoint={ + "v": 1, + "ts": "2021-09-02T00:00:00+00:00", + "channel_values": {}, + "channel_versions": defaultdict(), + "versions_seen": defaultdict(), + }, + parent_config=None, + ) + + +async def test_on_conflict_aput() -> None: + async with asyncpg_client() as async_connection: + await PostgresCheckpoint.adrop_schema(async_connection) + await PostgresCheckpoint.acreate_schema(async_connection) + checkpoint_saver = PostgresCheckpoint( + async_connection=async_connection, serializer=PickleCheckpointSerializer() + ) + + # aput with twice on the same (thread_id, thread_ts) should not raise any error + sample_checkpoint: Checkpoint = { + "v": 1, + "ts": "2021-09-01T00:00:00+00:00", + "channel_values": {}, + "channel_versions": defaultdict(), + "versions_seen": defaultdict(), + } + new_checkpoint: Checkpoint = { + "v": 2, + "ts": "2021-09-01T00:00:00+00:00", + "channel_values": {}, + "channel_versions": defaultdict(), + "versions_seen": defaultdict(), + } + await checkpoint_saver.aput( + { + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-01T00:00:00+00:00", + } + }, + sample_checkpoint, + ) + await checkpoint_saver.aput( + { + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-01T00:00:00+00:00", + } + }, + new_checkpoint, + ) + # Check aget_tuple with thread_ts + assert await checkpoint_saver.aget_tuple( + { + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-01T00:00:00+00:00", + } + } + ) == CheckpointTuple( + config={ + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-01T00:00:00+00:00", + } + }, + checkpoint={ + "v": 2, + "ts": "2021-09-01T00:00:00+00:00", + 
"channel_values": {}, + "channel_versions": defaultdict(None, {}), + "versions_seen": defaultdict(None, {}), + }, + parent_config={ + "configurable": { + "thread_id": "test_thread", + "thread_ts": "2021-09-01T00:00:00+00:00", + } + }, + ) diff --git a/libs/partners/postgres/tests/unit_tests/test_imports.py b/libs/partners/postgres/tests/unit_tests/test_imports.py index 59225c5b07d..761a273c1da 100644 --- a/libs/partners/postgres/tests/unit_tests/test_imports.py +++ b/libs/partners/postgres/tests/unit_tests/test_imports.py @@ -1,7 +1,14 @@ from langchain_postgres import __all__ -EXPECTED_ALL = ["__version__", "PostgresChatMessageHistory"] +EXPECTED_ALL = [ + "__version__", + "CheckpointSerializer", + "PostgresChatMessageHistory", + "PostgresCheckpoint", + "PickleCheckpointSerializer", +] def test_all_imports() -> None: + """Test that __all__ is correctly defined.""" assert sorted(EXPECTED_ALL) == sorted(__all__)