mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
Add connection args for pgvector vector store (#11930)
- **Description:** sqlalchemy create_engine() does not take into account connect_args which are mandatory for managed PGSQL instances on cloud providers (ssl_context for example). Also re-enabled create_vector_extension at post_init for using pgvector class seamlessly - **Tag maintainer:** @baskaryan, @eyurtsev, @hwchase17. --------- Co-authored-by: Sami Bargaoui <bargaoui.sam@gmail.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
4d6243fa87
commit
66f8cb015d
@ -84,6 +84,7 @@ class PGVector(VectorStore):
|
|||||||
distance_strategy: The distance strategy to use. (default: COSINE)
|
distance_strategy: The distance strategy to use. (default: COSINE)
|
||||||
pre_delete_collection: If True, will delete the collection if it exists.
|
pre_delete_collection: If True, will delete the collection if it exists.
|
||||||
(default: False). Useful for testing.
|
(default: False). Useful for testing.
|
||||||
|
engine_args: SQLAlchemy's create engine arguments.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -114,6 +115,8 @@ class PGVector(VectorStore):
|
|||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
logger: Optional[logging.Logger] = None,
|
logger: Optional[logging.Logger] = None,
|
||||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||||
|
*,
|
||||||
|
engine_args: Optional[dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.connection_string = connection_string
|
self.connection_string = connection_string
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
@ -123,6 +126,7 @@ class PGVector(VectorStore):
|
|||||||
self.pre_delete_collection = pre_delete_collection
|
self.pre_delete_collection = pre_delete_collection
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
self.override_relevance_score_fn = relevance_score_fn
|
self.override_relevance_score_fn = relevance_score_fn
|
||||||
|
self.engine_args = engine_args or {}
|
||||||
self.__post_init__()
|
self.__post_init__()
|
||||||
|
|
||||||
def __post_init__(
|
def __post_init__(
|
||||||
@ -132,7 +136,7 @@ class PGVector(VectorStore):
|
|||||||
Initialize the store.
|
Initialize the store.
|
||||||
"""
|
"""
|
||||||
self._conn = self.connect()
|
self._conn = self.connect()
|
||||||
# self.create_vector_extension()
|
self.create_vector_extension()
|
||||||
from langchain.vectorstores._pgvector_data_models import (
|
from langchain.vectorstores._pgvector_data_models import (
|
||||||
CollectionStore,
|
CollectionStore,
|
||||||
EmbeddingStore,
|
EmbeddingStore,
|
||||||
@ -148,7 +152,7 @@ class PGVector(VectorStore):
|
|||||||
return self.embedding_function
|
return self.embedding_function
|
||||||
|
|
||||||
def connect(self) -> sqlalchemy.engine.Connection:
|
def connect(self) -> sqlalchemy.engine.Connection:
|
||||||
engine = sqlalchemy.create_engine(self.connection_string)
|
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
|
||||||
conn = engine.connect()
|
conn = engine.connect()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
@ -159,7 +163,7 @@ class PGVector(VectorStore):
|
|||||||
session.execute(statement)
|
session.execute(statement)
|
||||||
session.commit()
|
session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception(e)
|
raise Exception(f"Failed to create vector extension: {e}") from e
|
||||||
|
|
||||||
def create_tables_if_not_exists(self) -> None:
|
def create_tables_if_not_exists(self) -> None:
|
||||||
with self._conn.begin():
|
with self._conn.begin():
|
||||||
|
Loading…
Reference in New Issue
Block a user