Harrison/pg vector move (#7580)

This commit is contained in:
Harrison Chase 2023-07-11 23:22:34 -07:00 committed by GitHub
parent 2667ddc686
commit 641fd74baa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 50 deletions

View File

@ -1,9 +1,50 @@
from typing import Optional, Tuple
import sqlalchemy import sqlalchemy
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects.postgresql import JSON, UUID from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import Session, relationship
from langchain.vectorstores.pgvector import BaseModel, CollectionStore from langchain.vectorstores.pgvector import BaseModel
class CollectionStore(BaseModel):
    """ORM model for rows of the ``langchain_pg_collection`` table.

    A collection carries a ``name`` and optional JSON ``cmetadata`` and is
    linked to its ``EmbeddingStore`` rows through ``embeddings``.
    """

    __tablename__ = "langchain_pg_collection"

    name = sqlalchemy.Column(sqlalchemy.String)
    cmetadata = sqlalchemy.Column(JSON)

    # passive_deletes=True: per SQLAlchemy, the ORM does not load the child
    # rows to delete them itself when a collection is deleted.
    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Return the first collection whose name equals ``name``, else None."""
        matching = session.query(cls).filter(cls.name == name)
        return matching.first()  # type: ignore

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """
        Get or create a collection.

        Returns a ``(collection, created)`` pair; ``created`` is True only
        when no collection with this name was found and a new one was added
        and committed through ``session``.
        """
        existing = cls.get_by_name(session, name)
        if existing:
            return existing, False

        fresh = cls(name=name, cmetadata=cmetadata)
        session.add(fresh)
        session.commit()
        return fresh, True
class EmbeddingStore(BaseModel): class EmbeddingStore(BaseModel):

View File

@ -4,17 +4,30 @@ from __future__ import annotations
import enum import enum
import logging import logging
import uuid import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
import sqlalchemy import sqlalchemy
from sqlalchemy.dialects.postgresql import JSON, UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Session, declarative_base, relationship from sqlalchemy.orm import Session, declarative_base
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
if TYPE_CHECKING:
from langchain.vectorstores._pgvector_data_models import CollectionStore
class DistanceStrategy(str, enum.Enum): class DistanceStrategy(str, enum.Enum):
"""Enumerator of the Distance strategies.""" """Enumerator of the Distance strategies."""
@ -37,45 +50,6 @@ class BaseModel(Base):
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
class CollectionStore(BaseModel):
    """Declarative model stored in the ``langchain_pg_collection`` table."""

    __tablename__ = "langchain_pg_collection"

    # Collection name used for lookups, plus free-form JSON metadata.
    name = sqlalchemy.Column(sqlalchemy.String)
    cmetadata = sqlalchemy.Column(JSON)

    # Link to the "EmbeddingStore" model, mirrored by its "collection" side.
    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Look up a collection by name; None when the query yields no row."""
        by_name = cls.name == name
        return session.query(cls).filter(by_name).first()  # type: ignore

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """
        Get or create a collection.

        Returns a tuple of the collection and a bool that is True if the
        collection was created by this call.
        """
        collection = cls.get_by_name(session, name)
        if not collection:
            # Nothing stored under this name yet: insert and commit a new row.
            collection = cls(name=name, cmetadata=cmetadata)
            session.add(collection)
            session.commit()
            return collection, True
        return collection, False
class PGVector(VectorStore): class PGVector(VectorStore):
"""VectorStore implementation using Postgres and pgvector. """VectorStore implementation using Postgres and pgvector.
@ -141,8 +115,12 @@ class PGVector(VectorStore):
""" """
self._conn = self.connect() self._conn = self.connect()
# self.create_vector_extension() # self.create_vector_extension()
from langchain.vectorstores._pgvector_data_models import EmbeddingStore from langchain.vectorstores._pgvector_data_models import (
CollectionStore,
EmbeddingStore,
)
self.CollectionStore = CollectionStore
self.EmbeddingStore = EmbeddingStore self.EmbeddingStore = EmbeddingStore
self.create_tables_if_not_exists() self.create_tables_if_not_exists()
self.create_collection() self.create_collection()
@ -173,7 +151,7 @@ class PGVector(VectorStore):
if self.pre_delete_collection: if self.pre_delete_collection:
self.delete_collection() self.delete_collection()
with Session(self._conn) as session: with Session(self._conn) as session:
CollectionStore.get_or_create( self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata session, self.collection_name, cmetadata=self.collection_metadata
) )
@ -188,7 +166,7 @@ class PGVector(VectorStore):
session.commit() session.commit()
def get_collection(self, session: Session) -> Optional["CollectionStore"]: def get_collection(self, session: Session) -> Optional["CollectionStore"]:
return CollectionStore.get_by_name(session, self.collection_name) return self.CollectionStore.get_by_name(session, self.collection_name)
@classmethod @classmethod
def __from( def __from(
@ -200,6 +178,7 @@ class PGVector(VectorStore):
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
connection_string: Optional[str] = None,
pre_delete_collection: bool = False, pre_delete_collection: bool = False,
**kwargs: Any, **kwargs: Any,
) -> PGVector: ) -> PGVector:
@ -208,6 +187,7 @@ class PGVector(VectorStore):
if not metadatas: if not metadatas:
metadatas = [{} for _ in texts] metadatas = [{} for _ in texts]
if connection_string is None:
connection_string = cls.get_connection_string(kwargs) connection_string = cls.get_connection_string(kwargs)
store = cls( store = cls(
@ -389,8 +369,8 @@ class PGVector(VectorStore):
.filter(filter_by) .filter(filter_by)
.order_by(sqlalchemy.asc("distance")) .order_by(sqlalchemy.asc("distance"))
.join( .join(
CollectionStore, self.CollectionStore,
self.EmbeddingStore.collection_id == CollectionStore.uuid, self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
) )
.limit(k) .limit(k)
.all() .all()