mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Add connection args for pgvector vector store (#11930)
- **Description:** sqlalchemy create_engine() does not take into account connect_args which are mandatory for managed PGSQL instances on cloud providers (ssl_context for example). Also re-enabled create_vector_extension at post_init for using pgvector class seamlessly - **Tag maintainer:** @baskaryan, @eyurtsev, @hwchase17. --------- Co-authored-by: Sami Bargaoui <bargaoui.sam@gmail.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
4d6243fa87
commit
66f8cb015d
@ -84,6 +84,7 @@ class PGVector(VectorStore):
|
||||
distance_strategy: The distance strategy to use. (default: COSINE)
|
||||
pre_delete_collection: If True, will delete the collection if it exists.
|
||||
(default: False). Useful for testing.
|
||||
engine_args: SQLAlchemy's create engine arguments.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@ -114,6 +115,8 @@ class PGVector(VectorStore):
|
||||
pre_delete_collection: bool = False,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
*,
|
||||
engine_args: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.connection_string = connection_string
|
||||
self.embedding_function = embedding_function
|
||||
@ -123,6 +126,7 @@ class PGVector(VectorStore):
|
||||
self.pre_delete_collection = pre_delete_collection
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
self.engine_args = engine_args or {}
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(
|
||||
@ -132,7 +136,7 @@ class PGVector(VectorStore):
|
||||
Initialize the store.
|
||||
"""
|
||||
self._conn = self.connect()
|
||||
# self.create_vector_extension()
|
||||
self.create_vector_extension()
|
||||
from langchain.vectorstores._pgvector_data_models import (
|
||||
CollectionStore,
|
||||
EmbeddingStore,
|
||||
@ -148,7 +152,7 @@ class PGVector(VectorStore):
|
||||
return self.embedding_function
|
||||
|
||||
def connect(self) -> sqlalchemy.engine.Connection:
|
||||
engine = sqlalchemy.create_engine(self.connection_string)
|
||||
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
|
||||
conn = engine.connect()
|
||||
return conn
|
||||
|
||||
@ -159,7 +163,7 @@ class PGVector(VectorStore):
|
||||
session.execute(statement)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
self.logger.exception(e)
|
||||
raise Exception(f"Failed to create vector extension: {e}") from e
|
||||
|
||||
def create_tables_if_not_exists(self) -> None:
|
||||
with self._conn.begin():
|
||||
|
Loading…
Reference in New Issue
Block a user