mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
refactor: enable connection pool usage in PGVector (#11514)
- **Description:** `PGVector` refactored to use connection pool. - **Issue:** #11433, - **Tag maintainer:** @hwchase17 @eyurtsev, --------- Co-authored-by: Diego Rani Mazine <diego.mazine@mercadolivre.com> Co-authored-by: Nuno Campos <nuno@langchain.dev>
This commit is contained in:
parent
507c195a4b
commit
ec72225265
@ -203,8 +203,7 @@ class PGVector(VectorStore):
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
self.engine_args = engine_args or {}
|
||||
# Create a connection if not provided, otherwise use the provided connection
|
||||
self._conn = connection if connection else self.connect()
|
||||
self._bind = connection if connection else self._create_engine()
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(
|
||||
@ -220,21 +219,19 @@ class PGVector(VectorStore):
|
||||
self.create_collection()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
if isinstance(self._bind, sqlalchemy.engine.Connection):
|
||||
self._bind.close()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_function
|
||||
|
||||
def connect(self) -> sqlalchemy.engine.Connection:
|
||||
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
|
||||
conn = engine.connect()
|
||||
return conn
|
||||
def _create_engine(self) -> sqlalchemy.engine.Engine:
|
||||
return sqlalchemy.create_engine(url=self.connection_string, **self.engine_args)
|
||||
|
||||
def create_vector_extension(self) -> None:
|
||||
try:
|
||||
with Session(self._conn) as session:
|
||||
with Session(self._bind) as session:
|
||||
# The advisor lock fixes issue arising from concurrent
|
||||
# creation of the vector extension.
|
||||
# https://github.com/langchain-ai/langchain/issues/12933
|
||||
@ -252,24 +249,24 @@ class PGVector(VectorStore):
|
||||
raise Exception(f"Failed to create vector extension: {e}") from e
|
||||
|
||||
def create_tables_if_not_exists(self) -> None:
|
||||
with self._conn.begin():
|
||||
Base.metadata.create_all(self._conn)
|
||||
with Session(self._bind) as session, session.begin():
|
||||
Base.metadata.create_all(session.get_bind())
|
||||
|
||||
def drop_tables(self) -> None:
|
||||
with self._conn.begin():
|
||||
Base.metadata.drop_all(self._conn)
|
||||
with Session(self._bind) as session, session.begin():
|
||||
Base.metadata.drop_all(session.get_bind())
|
||||
|
||||
def create_collection(self) -> None:
|
||||
if self.pre_delete_collection:
|
||||
self.delete_collection()
|
||||
with Session(self._conn) as session:
|
||||
with Session(self._bind) as session:
|
||||
self.CollectionStore.get_or_create(
|
||||
session, self.collection_name, cmetadata=self.collection_metadata
|
||||
)
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
self.logger.debug("Trying to delete collection")
|
||||
with Session(self._conn) as session:
|
||||
with Session(self._bind) as session:
|
||||
collection = self.get_collection(session)
|
||||
if not collection:
|
||||
self.logger.warning("Collection not found")
|
||||
@ -280,7 +277,7 @@ class PGVector(VectorStore):
|
||||
@contextlib.contextmanager
|
||||
def _make_session(self) -> Generator[Session, None, None]:
|
||||
"""Create a context manager for the session, bind to _conn string."""
|
||||
yield Session(self._conn)
|
||||
yield Session(self._bind)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@ -292,7 +289,7 @@ class PGVector(VectorStore):
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
"""
|
||||
with Session(self._conn) as session:
|
||||
with Session(self._bind) as session:
|
||||
if ids is not None:
|
||||
self.logger.debug(
|
||||
"Trying to delete vectors by ids (represented by the model "
|
||||
@ -366,7 +363,7 @@ class PGVector(VectorStore):
|
||||
if not metadatas:
|
||||
metadatas = [{} for _ in texts]
|
||||
|
||||
with Session(self._conn) as session:
|
||||
with Session(self._bind) as session:
|
||||
collection = self.get_collection(session)
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
@ -496,7 +493,7 @@ class PGVector(VectorStore):
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
) -> List[Any]:
|
||||
"""Query the collection."""
|
||||
with Session(self._conn) as session:
|
||||
with Session(self._bind) as session:
|
||||
collection = self.get_collection(session)
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
|
@ -0,0 +1,17 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
pgvector:
|
||||
image: ankane/pgvector:latest
|
||||
environment:
|
||||
POSTGRES_DB: ${PGVECTOR_DB:-postgres}
|
||||
POSTGRES_USER: ${PGVECTOR_USER:-postgres}
|
||||
POSTGRES_PASSWORD: ${PGVECTOR_PASSWORD:-postgres}
|
||||
ports:
|
||||
- ${PGVECTOR_PORT:-5432}:5432
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:5432"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
@ -2,6 +2,7 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import sqlalchemy
|
||||
from langchain_core.documents import Document
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -155,7 +156,7 @@ def test_pgvector_collection_with_metadata() -> None:
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
session = Session(pgvector.connect())
|
||||
session = Session(pgvector._create_engine())
|
||||
collection = pgvector.get_collection(session)
|
||||
if collection is None:
|
||||
assert False, "Expected a CollectionStore object but received None"
|
||||
@ -327,3 +328,43 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None:
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
|
||||
assert output == [(Document(page_content="foo"), 0.0)]
|
||||
|
||||
|
||||
def test_pgvector_with_custom_connection() -> None:
|
||||
"""Test construction using a custom connection."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
engine = sqlalchemy.create_engine(CONNECTION_STRING)
|
||||
with engine.connect() as connection:
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
connection=connection,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_pgvector_with_custom_engine_args() -> None:
|
||||
"""Test construction using custom engine arguments."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
engine_args = {
|
||||
"pool_size": 5,
|
||||
"max_overflow": 10,
|
||||
"pool_recycle": -1,
|
||||
"pool_use_lifo": False,
|
||||
"pool_pre_ping": False,
|
||||
"pool_timeout": 30,
|
||||
}
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
engine_args=engine_args,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
@ -0,0 +1,73 @@
|
||||
"""Test PGVector functionality."""
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.embeddings import FakeEmbeddings
|
||||
from langchain_community.vectorstores import pgvector
|
||||
|
||||
_CONNECTION_STRING = pgvector.PGVector.connection_string_from_db_params(
|
||||
driver="psycopg2",
|
||||
host="localhost",
|
||||
port=5432,
|
||||
database="postgres",
|
||||
user="postgres",
|
||||
password="postgres",
|
||||
)
|
||||
|
||||
_EMBEDDING_FUNCTION = FakeEmbeddings(size=1536)
|
||||
|
||||
|
||||
@pytest.mark.requires("pgvector")
|
||||
@mock.patch("sqlalchemy.create_engine")
|
||||
def test_given_a_connection_is_provided_then_no_engine_should_be_created(
|
||||
create_engine: Mock,
|
||||
) -> None:
|
||||
"""When a connection is provided then no engine should be created."""
|
||||
pgvector.PGVector(
|
||||
connection_string=_CONNECTION_STRING,
|
||||
embedding_function=_EMBEDDING_FUNCTION,
|
||||
connection=mock.MagicMock(),
|
||||
)
|
||||
create_engine.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.requires("pgvector")
|
||||
@mock.patch("sqlalchemy.create_engine")
|
||||
def test_given_no_connection_or_engine_args_provided_default_engine_should_be_used(
|
||||
create_engine: Mock,
|
||||
) -> None:
|
||||
"""When no connection or engine arguments are provided then the default configuration must be used.""" # noqa: E501
|
||||
pgvector.PGVector(
|
||||
connection_string=_CONNECTION_STRING,
|
||||
embedding_function=_EMBEDDING_FUNCTION,
|
||||
)
|
||||
create_engine.assert_called_with(
|
||||
url=_CONNECTION_STRING,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("pgvector")
|
||||
@mock.patch("sqlalchemy.create_engine")
|
||||
def test_given_engine_args_are_provided_then_they_should_be_used(
|
||||
create_engine: Mock,
|
||||
) -> None:
|
||||
"""When engine arguments are provided then they must be used to create the underlying engine.""" # noqa: E501
|
||||
engine_args = {
|
||||
"pool_size": 5,
|
||||
"max_overflow": 10,
|
||||
"pool_recycle": -1,
|
||||
"pool_use_lifo": False,
|
||||
"pool_pre_ping": False,
|
||||
"pool_timeout": 30,
|
||||
}
|
||||
pgvector.PGVector(
|
||||
connection_string=_CONNECTION_STRING,
|
||||
embedding_function=_EMBEDDING_FUNCTION,
|
||||
engine_args=engine_args,
|
||||
)
|
||||
create_engine.assert_called_with(
|
||||
url=_CONNECTION_STRING,
|
||||
**engine_args,
|
||||
)
|
Loading…
Reference in New Issue
Block a user