mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 19:57:51 +00:00
Harrison/pg vector move (#7580)
This commit is contained in:
parent
2667ddc686
commit
641fd74baa
@ -1,9 +1,50 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy.dialects.postgresql import JSON, UUID
|
from sqlalchemy.dialects.postgresql import JSON, UUID
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import Session, relationship
|
||||||
|
|
||||||
from langchain.vectorstores.pgvector import BaseModel, CollectionStore
|
from langchain.vectorstores.pgvector import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionStore(BaseModel):
|
||||||
|
__tablename__ = "langchain_pg_collection"
|
||||||
|
|
||||||
|
name = sqlalchemy.Column(sqlalchemy.String)
|
||||||
|
cmetadata = sqlalchemy.Column(JSON)
|
||||||
|
|
||||||
|
embeddings = relationship(
|
||||||
|
"EmbeddingStore",
|
||||||
|
back_populates="collection",
|
||||||
|
passive_deletes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
|
||||||
|
return session.query(cls).filter(cls.name == name).first() # type: ignore
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_or_create(
|
||||||
|
cls,
|
||||||
|
session: Session,
|
||||||
|
name: str,
|
||||||
|
cmetadata: Optional[dict] = None,
|
||||||
|
) -> Tuple["CollectionStore", bool]:
|
||||||
|
"""
|
||||||
|
Get or create a collection.
|
||||||
|
Returns [Collection, bool] where the bool is True if the collection was created.
|
||||||
|
"""
|
||||||
|
created = False
|
||||||
|
collection = cls.get_by_name(session, name)
|
||||||
|
if collection:
|
||||||
|
return collection, created
|
||||||
|
|
||||||
|
collection = cls(name=name, cmetadata=cmetadata)
|
||||||
|
session.add(collection)
|
||||||
|
session.commit()
|
||||||
|
created = True
|
||||||
|
return collection, created
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingStore(BaseModel):
|
class EmbeddingStore(BaseModel):
|
||||||
|
@ -4,17 +4,30 @@ from __future__ import annotations
|
|||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy.dialects.postgresql import JSON, UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.orm import Session, declarative_base, relationship
|
from sqlalchemy.orm import Session, declarative_base
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain.vectorstores._pgvector_data_models import CollectionStore
|
||||||
|
|
||||||
|
|
||||||
class DistanceStrategy(str, enum.Enum):
|
class DistanceStrategy(str, enum.Enum):
|
||||||
"""Enumerator of the Distance strategies."""
|
"""Enumerator of the Distance strategies."""
|
||||||
@ -37,45 +50,6 @@ class BaseModel(Base):
|
|||||||
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
|
||||||
|
|
||||||
class CollectionStore(BaseModel):
|
|
||||||
__tablename__ = "langchain_pg_collection"
|
|
||||||
|
|
||||||
name = sqlalchemy.Column(sqlalchemy.String)
|
|
||||||
cmetadata = sqlalchemy.Column(JSON)
|
|
||||||
|
|
||||||
embeddings = relationship(
|
|
||||||
"EmbeddingStore",
|
|
||||||
back_populates="collection",
|
|
||||||
passive_deletes=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
|
|
||||||
return session.query(cls).filter(cls.name == name).first() # type: ignore
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_or_create(
|
|
||||||
cls,
|
|
||||||
session: Session,
|
|
||||||
name: str,
|
|
||||||
cmetadata: Optional[dict] = None,
|
|
||||||
) -> Tuple["CollectionStore", bool]:
|
|
||||||
"""
|
|
||||||
Get or create a collection.
|
|
||||||
Returns [Collection, bool] where the bool is True if the collection was created.
|
|
||||||
"""
|
|
||||||
created = False
|
|
||||||
collection = cls.get_by_name(session, name)
|
|
||||||
if collection:
|
|
||||||
return collection, created
|
|
||||||
|
|
||||||
collection = cls(name=name, cmetadata=cmetadata)
|
|
||||||
session.add(collection)
|
|
||||||
session.commit()
|
|
||||||
created = True
|
|
||||||
return collection, created
|
|
||||||
|
|
||||||
|
|
||||||
class PGVector(VectorStore):
|
class PGVector(VectorStore):
|
||||||
"""VectorStore implementation using Postgres and pgvector.
|
"""VectorStore implementation using Postgres and pgvector.
|
||||||
|
|
||||||
@ -141,8 +115,12 @@ class PGVector(VectorStore):
|
|||||||
"""
|
"""
|
||||||
self._conn = self.connect()
|
self._conn = self.connect()
|
||||||
# self.create_vector_extension()
|
# self.create_vector_extension()
|
||||||
from langchain.vectorstores._pgvector_data_models import EmbeddingStore
|
from langchain.vectorstores._pgvector_data_models import (
|
||||||
|
CollectionStore,
|
||||||
|
EmbeddingStore,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.CollectionStore = CollectionStore
|
||||||
self.EmbeddingStore = EmbeddingStore
|
self.EmbeddingStore = EmbeddingStore
|
||||||
self.create_tables_if_not_exists()
|
self.create_tables_if_not_exists()
|
||||||
self.create_collection()
|
self.create_collection()
|
||||||
@ -173,7 +151,7 @@ class PGVector(VectorStore):
|
|||||||
if self.pre_delete_collection:
|
if self.pre_delete_collection:
|
||||||
self.delete_collection()
|
self.delete_collection()
|
||||||
with Session(self._conn) as session:
|
with Session(self._conn) as session:
|
||||||
CollectionStore.get_or_create(
|
self.CollectionStore.get_or_create(
|
||||||
session, self.collection_name, cmetadata=self.collection_metadata
|
session, self.collection_name, cmetadata=self.collection_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -188,7 +166,7 @@ class PGVector(VectorStore):
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def get_collection(self, session: Session) -> Optional["CollectionStore"]:
|
def get_collection(self, session: Session) -> Optional["CollectionStore"]:
|
||||||
return CollectionStore.get_by_name(session, self.collection_name)
|
return self.CollectionStore.get_by_name(session, self.collection_name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __from(
|
def __from(
|
||||||
@ -200,6 +178,7 @@ class PGVector(VectorStore):
|
|||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||||
|
connection_string: Optional[str] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> PGVector:
|
) -> PGVector:
|
||||||
@ -208,7 +187,8 @@ class PGVector(VectorStore):
|
|||||||
|
|
||||||
if not metadatas:
|
if not metadatas:
|
||||||
metadatas = [{} for _ in texts]
|
metadatas = [{} for _ in texts]
|
||||||
connection_string = cls.get_connection_string(kwargs)
|
if connection_string is None:
|
||||||
|
connection_string = cls.get_connection_string(kwargs)
|
||||||
|
|
||||||
store = cls(
|
store = cls(
|
||||||
connection_string=connection_string,
|
connection_string=connection_string,
|
||||||
@ -389,8 +369,8 @@ class PGVector(VectorStore):
|
|||||||
.filter(filter_by)
|
.filter(filter_by)
|
||||||
.order_by(sqlalchemy.asc("distance"))
|
.order_by(sqlalchemy.asc("distance"))
|
||||||
.join(
|
.join(
|
||||||
CollectionStore,
|
self.CollectionStore,
|
||||||
self.EmbeddingStore.collection_id == CollectionStore.uuid,
|
self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
|
||||||
)
|
)
|
||||||
.limit(k)
|
.limit(k)
|
||||||
.all()
|
.all()
|
||||||
|
Loading…
Reference in New Issue
Block a user