mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 02:33:05 +00:00
refactor: enable connection pool usage in PGVector (#11514)
- **Description:** Refactored `PGVector` to use a connection pool.
- **Issue:** #11433
- **Tag maintainer:** @hwchase17 @eyurtsev

---------

Co-authored-by: Diego Rani Mazine <diego.mazine@mercadolivre.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
This commit is contained in:
parent
507c195a4b
commit
ec72225265
@ -203,8 +203,7 @@ class PGVector(VectorStore):
|
|||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
self.override_relevance_score_fn = relevance_score_fn
|
self.override_relevance_score_fn = relevance_score_fn
|
||||||
self.engine_args = engine_args or {}
|
self.engine_args = engine_args or {}
|
||||||
# Create a connection if not provided, otherwise use the provided connection
|
self._bind = connection if connection else self._create_engine()
|
||||||
self._conn = connection if connection else self.connect()
|
|
||||||
self.__post_init__()
|
self.__post_init__()
|
||||||
|
|
||||||
def __post_init__(
|
def __post_init__(
|
||||||
@ -220,21 +219,19 @@ class PGVector(VectorStore):
|
|||||||
self.create_collection()
|
self.create_collection()
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
if self._conn:
|
if isinstance(self._bind, sqlalchemy.engine.Connection):
|
||||||
self._conn.close()
|
self._bind.close()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def embeddings(self) -> Embeddings:
|
def embeddings(self) -> Embeddings:
|
||||||
return self.embedding_function
|
return self.embedding_function
|
||||||
|
|
||||||
def connect(self) -> sqlalchemy.engine.Connection:
|
def _create_engine(self) -> sqlalchemy.engine.Engine:
|
||||||
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
|
return sqlalchemy.create_engine(url=self.connection_string, **self.engine_args)
|
||||||
conn = engine.connect()
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def create_vector_extension(self) -> None:
|
def create_vector_extension(self) -> None:
|
||||||
try:
|
try:
|
||||||
with Session(self._conn) as session:
|
with Session(self._bind) as session:
|
||||||
# The advisor lock fixes issue arising from concurrent
|
# The advisor lock fixes issue arising from concurrent
|
||||||
# creation of the vector extension.
|
# creation of the vector extension.
|
||||||
# https://github.com/langchain-ai/langchain/issues/12933
|
# https://github.com/langchain-ai/langchain/issues/12933
|
||||||
@ -252,24 +249,24 @@ class PGVector(VectorStore):
|
|||||||
raise Exception(f"Failed to create vector extension: {e}") from e
|
raise Exception(f"Failed to create vector extension: {e}") from e
|
||||||
|
|
||||||
def create_tables_if_not_exists(self) -> None:
|
def create_tables_if_not_exists(self) -> None:
|
||||||
with self._conn.begin():
|
with Session(self._bind) as session, session.begin():
|
||||||
Base.metadata.create_all(self._conn)
|
Base.metadata.create_all(session.get_bind())
|
||||||
|
|
||||||
def drop_tables(self) -> None:
|
def drop_tables(self) -> None:
|
||||||
with self._conn.begin():
|
with Session(self._bind) as session, session.begin():
|
||||||
Base.metadata.drop_all(self._conn)
|
Base.metadata.drop_all(session.get_bind())
|
||||||
|
|
||||||
def create_collection(self) -> None:
|
def create_collection(self) -> None:
|
||||||
if self.pre_delete_collection:
|
if self.pre_delete_collection:
|
||||||
self.delete_collection()
|
self.delete_collection()
|
||||||
with Session(self._conn) as session:
|
with Session(self._bind) as session:
|
||||||
self.CollectionStore.get_or_create(
|
self.CollectionStore.get_or_create(
|
||||||
session, self.collection_name, cmetadata=self.collection_metadata
|
session, self.collection_name, cmetadata=self.collection_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_collection(self) -> None:
|
def delete_collection(self) -> None:
|
||||||
self.logger.debug("Trying to delete collection")
|
self.logger.debug("Trying to delete collection")
|
||||||
with Session(self._conn) as session:
|
with Session(self._bind) as session:
|
||||||
collection = self.get_collection(session)
|
collection = self.get_collection(session)
|
||||||
if not collection:
|
if not collection:
|
||||||
self.logger.warning("Collection not found")
|
self.logger.warning("Collection not found")
|
||||||
@ -280,7 +277,7 @@ class PGVector(VectorStore):
|
|||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _make_session(self) -> Generator[Session, None, None]:
|
def _make_session(self) -> Generator[Session, None, None]:
|
||||||
"""Create a context manager for the session, bind to _conn string."""
|
"""Create a context manager for the session, bind to _conn string."""
|
||||||
yield Session(self._conn)
|
yield Session(self._bind)
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
@ -292,7 +289,7 @@ class PGVector(VectorStore):
|
|||||||
Args:
|
Args:
|
||||||
ids: List of ids to delete.
|
ids: List of ids to delete.
|
||||||
"""
|
"""
|
||||||
with Session(self._conn) as session:
|
with Session(self._bind) as session:
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Trying to delete vectors by ids (represented by the model "
|
"Trying to delete vectors by ids (represented by the model "
|
||||||
@ -366,7 +363,7 @@ class PGVector(VectorStore):
|
|||||||
if not metadatas:
|
if not metadatas:
|
||||||
metadatas = [{} for _ in texts]
|
metadatas = [{} for _ in texts]
|
||||||
|
|
||||||
with Session(self._conn) as session:
|
with Session(self._bind) as session:
|
||||||
collection = self.get_collection(session)
|
collection = self.get_collection(session)
|
||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError("Collection not found")
|
raise ValueError("Collection not found")
|
||||||
@ -496,7 +493,7 @@ class PGVector(VectorStore):
|
|||||||
filter: Optional[Dict[str, str]] = None,
|
filter: Optional[Dict[str, str]] = None,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""Query the collection."""
|
"""Query the collection."""
|
||||||
with Session(self._conn) as session:
|
with Session(self._bind) as session:
|
||||||
collection = self.get_collection(session)
|
collection = self.get_collection(session)
|
||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError("Collection not found")
|
raise ValueError("Collection not found")
|
||||||
|
@ -0,0 +1,17 @@
|
|||||||
|
version: "3.8"
|
||||||
|
|
||||||
|
services:
|
||||||
|
pgvector:
|
||||||
|
image: ankane/pgvector:latest
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: ${PGVECTOR_DB:-postgres}
|
||||||
|
POSTGRES_USER: ${PGVECTOR_USER:-postgres}
|
||||||
|
POSTGRES_PASSWORD: ${PGVECTOR_PASSWORD:-postgres}
|
||||||
|
ports:
|
||||||
|
- ${PGVECTOR_PORT:-5432}:5432
|
||||||
|
restart: unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:5432"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
@ -2,6 +2,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -155,7 +156,7 @@ def test_pgvector_collection_with_metadata() -> None:
|
|||||||
connection_string=CONNECTION_STRING,
|
connection_string=CONNECTION_STRING,
|
||||||
pre_delete_collection=True,
|
pre_delete_collection=True,
|
||||||
)
|
)
|
||||||
session = Session(pgvector.connect())
|
session = Session(pgvector._create_engine())
|
||||||
collection = pgvector.get_collection(session)
|
collection = pgvector.get_collection(session)
|
||||||
if collection is None:
|
if collection is None:
|
||||||
assert False, "Expected a CollectionStore object but received None"
|
assert False, "Expected a CollectionStore object but received None"
|
||||||
@ -327,3 +328,43 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None:
|
|||||||
)
|
)
|
||||||
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
|
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
|
||||||
assert output == [(Document(page_content="foo"), 0.0)]
|
assert output == [(Document(page_content="foo"), 0.0)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pgvector_with_custom_connection() -> None:
|
||||||
|
"""Test construction using a custom connection."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
engine = sqlalchemy.create_engine(CONNECTION_STRING)
|
||||||
|
with engine.connect() as connection:
|
||||||
|
docsearch = PGVector.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
collection_name="test_collection",
|
||||||
|
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
connection=connection,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pgvector_with_custom_engine_args() -> None:
|
||||||
|
"""Test construction using custom engine arguments."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
engine_args = {
|
||||||
|
"pool_size": 5,
|
||||||
|
"max_overflow": 10,
|
||||||
|
"pool_recycle": -1,
|
||||||
|
"pool_use_lifo": False,
|
||||||
|
"pool_pre_ping": False,
|
||||||
|
"pool_timeout": 30,
|
||||||
|
}
|
||||||
|
docsearch = PGVector.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
collection_name="test_collection",
|
||||||
|
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
engine_args=engine_args,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
|
assert output == [Document(page_content="foo")]
|
||||||
|
@ -0,0 +1,73 @@
|
|||||||
|
"""Test PGVector functionality."""
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_community.embeddings import FakeEmbeddings
|
||||||
|
from langchain_community.vectorstores import pgvector
|
||||||
|
|
||||||
|
_CONNECTION_STRING = pgvector.PGVector.connection_string_from_db_params(
|
||||||
|
driver="psycopg2",
|
||||||
|
host="localhost",
|
||||||
|
port=5432,
|
||||||
|
database="postgres",
|
||||||
|
user="postgres",
|
||||||
|
password="postgres",
|
||||||
|
)
|
||||||
|
|
||||||
|
_EMBEDDING_FUNCTION = FakeEmbeddings(size=1536)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("pgvector")
|
||||||
|
@mock.patch("sqlalchemy.create_engine")
|
||||||
|
def test_given_a_connection_is_provided_then_no_engine_should_be_created(
|
||||||
|
create_engine: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""When a connection is provided then no engine should be created."""
|
||||||
|
pgvector.PGVector(
|
||||||
|
connection_string=_CONNECTION_STRING,
|
||||||
|
embedding_function=_EMBEDDING_FUNCTION,
|
||||||
|
connection=mock.MagicMock(),
|
||||||
|
)
|
||||||
|
create_engine.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("pgvector")
|
||||||
|
@mock.patch("sqlalchemy.create_engine")
|
||||||
|
def test_given_no_connection_or_engine_args_provided_default_engine_should_be_used(
|
||||||
|
create_engine: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""When no connection or engine arguments are provided then the default configuration must be used.""" # noqa: E501
|
||||||
|
pgvector.PGVector(
|
||||||
|
connection_string=_CONNECTION_STRING,
|
||||||
|
embedding_function=_EMBEDDING_FUNCTION,
|
||||||
|
)
|
||||||
|
create_engine.assert_called_with(
|
||||||
|
url=_CONNECTION_STRING,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("pgvector")
|
||||||
|
@mock.patch("sqlalchemy.create_engine")
|
||||||
|
def test_given_engine_args_are_provided_then_they_should_be_used(
|
||||||
|
create_engine: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""When engine arguments are provided then they must be used to create the underlying engine.""" # noqa: E501
|
||||||
|
engine_args = {
|
||||||
|
"pool_size": 5,
|
||||||
|
"max_overflow": 10,
|
||||||
|
"pool_recycle": -1,
|
||||||
|
"pool_use_lifo": False,
|
||||||
|
"pool_pre_ping": False,
|
||||||
|
"pool_timeout": 30,
|
||||||
|
}
|
||||||
|
pgvector.PGVector(
|
||||||
|
connection_string=_CONNECTION_STRING,
|
||||||
|
embedding_function=_EMBEDDING_FUNCTION,
|
||||||
|
engine_args=engine_args,
|
||||||
|
)
|
||||||
|
create_engine.assert_called_with(
|
||||||
|
url=_CONNECTION_STRING,
|
||||||
|
**engine_args,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user