diff --git a/langchain/vectorstores/_pgvector_data_models.py b/langchain/vectorstores/_pgvector_data_models.py index 1be27e5533f..f44bd2e3527 100644 --- a/langchain/vectorstores/_pgvector_data_models.py +++ b/langchain/vectorstores/_pgvector_data_models.py @@ -1,9 +1,50 @@ +from typing import Optional, Tuple + import sqlalchemy from pgvector.sqlalchemy import Vector from sqlalchemy.dialects.postgresql import JSON, UUID -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session, relationship -from langchain.vectorstores.pgvector import BaseModel, CollectionStore +from langchain.vectorstores.pgvector import BaseModel + + +class CollectionStore(BaseModel): + __tablename__ = "langchain_pg_collection" + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata = sqlalchemy.Column(JSON) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]: + return session.query(cls).filter(cls.name == name).first() # type: ignore + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True if the collection was created. + """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True + return collection, created class EmbeddingStore(BaseModel): diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index e3fcb3f04a5..b0958b78fa5 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -4,17 +4,30 @@ from __future__ import annotations import enum import logging import uuid -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, +) import sqlalchemy -from sqlalchemy.dialects.postgresql import JSON, UUID -from sqlalchemy.orm import Session, declarative_base, relationship +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Session, declarative_base from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore +if TYPE_CHECKING: + from langchain.vectorstores._pgvector_data_models import CollectionStore + class DistanceStrategy(str, enum.Enum): """Enumerator of the Distance strategies.""" @@ -37,45 +50,6 @@ class BaseModel(Base): uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) -class CollectionStore(BaseModel): - __tablename__ = "langchain_pg_collection" - - name = sqlalchemy.Column(sqlalchemy.String) - cmetadata = sqlalchemy.Column(JSON) - - embeddings = relationship( - "EmbeddingStore", - back_populates="collection", - passive_deletes=True, - ) - - @classmethod - def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]: - return session.query(cls).filter(cls.name == name).first() # type: ignore - - @classmethod - def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, - ) -> Tuple["CollectionStore", bool]: - """ - Get or create a collection. - Returns [Collection, bool] where the bool is True if the collection was created. - """ - created = False - collection = cls.get_by_name(session, name) - if collection: - return collection, created - - collection = cls(name=name, cmetadata=cmetadata) - session.add(collection) - session.commit() - created = True - return collection, created - - class PGVector(VectorStore): """VectorStore implementation using Postgres and pgvector. @@ -141,8 +115,12 @@ class PGVector(VectorStore): """ self._conn = self.connect() # self.create_vector_extension() - from langchain.vectorstores._pgvector_data_models import EmbeddingStore + from langchain.vectorstores._pgvector_data_models import ( + CollectionStore, + EmbeddingStore, + ) + self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore self.create_tables_if_not_exists() self.create_collection() @@ -173,7 +151,7 @@ class PGVector(VectorStore): if self.pre_delete_collection: self.delete_collection() with Session(self._conn) as session: - CollectionStore.get_or_create( + self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) @@ -188,7 +166,7 @@ class PGVector(VectorStore): session.commit() def get_collection(self, session: Session) -> Optional["CollectionStore"]: - return CollectionStore.get_by_name(session, self.collection_name) + return self.CollectionStore.get_by_name(session, self.collection_name) @classmethod def __from( @@ -200,6 +178,7 @@ class PGVector(VectorStore): ids: Optional[List[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection_string: Optional[str] = None, pre_delete_collection: bool = False, **kwargs: Any, ) -> PGVector: @@ -208,7 +187,8 @@ class PGVector(VectorStore): if not metadatas: metadatas = [{} for _ in texts] - connection_string = cls.get_connection_string(kwargs) + if connection_string is None: + connection_string = cls.get_connection_string(kwargs) store = cls( connection_string=connection_string, @@ -389,8 +369,8 @@ class PGVector(VectorStore): .filter(filter_by) .order_by(sqlalchemy.asc("distance")) .join( - CollectionStore, - self.EmbeddingStore.collection_id == CollectionStore.uuid, + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, ) .limit(k) .all()