mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
Compare commits
5 Commits
langchain-
...
dev2049/pg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2236d1deec | ||
|
|
ce5ee49bd2 | ||
|
|
a40a59a83c | ||
|
|
0d3f8a3c61 | ||
|
|
97abe337ba |
@@ -24,6 +24,10 @@ To import this vectorstore:
|
|||||||
from langchain.vectorstores.pgvector import PGVector
|
from langchain.vectorstores.pgvector import PGVector
|
||||||
```
|
```
|
||||||
|
|
||||||
|
PGVector embedding size is not autodetected. If you are using ChatGPT or any other embedding with 1536 dimensions
|
||||||
|
default is fine. If you are going to use for example HuggingFaceEmbeddings you need to set the environment variable named `PGVECTOR_VECTOR_SIZE`
|
||||||
|
to the needed value, In case of HuggingFaceEmbeddings is would be: `PGVECTOR_VECTOR_SIZE=768`
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
For a more detailed walkthrough of the PGVector Wrapper, see [this notebook](../modules/indexes/vectorstores/examples/pgvector.ipynb)
|
For a more detailed walkthrough of the PGVector Wrapper, see [this notebook](../modules/indexes/vectorstores/examples/pgvector.ipynb)
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
"""VectorStore wrapper around a Postgres/PGVector database."""
|
"""VectorStore wrapper around a Postgres/PGVector database."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Type
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pgvector.sqlalchemy import Vector
|
|
||||||
from sqlalchemy.dialects.postgresql import JSON, UUID
|
from sqlalchemy.dialects.postgresql import JSON, UUID
|
||||||
from sqlalchemy.orm import Session, declarative_base, relationship
|
from sqlalchemy.orm import Session, declarative_base, declared_attr, relationship
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@@ -19,8 +17,10 @@ from langchain.vectorstores.base import VectorStore
|
|||||||
Base = declarative_base() # type: Any
|
Base = declarative_base() # type: Any
|
||||||
|
|
||||||
|
|
||||||
ADA_TOKEN_COUNT = 1536
|
PGVECTOR_VECTOR_SIZE = 1536
|
||||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||||
|
DEFAULT_DISTANCE_STRATEGY = "cosine"
|
||||||
|
DistanceStrategy = Literal["cosine", "euclidean", "max_inner_product"]
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(Base):
|
class BaseModel(Base):
|
||||||
@@ -67,7 +67,7 @@ class CollectionStore(BaseModel):
|
|||||||
return collection, created
|
return collection, created
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingStore(BaseModel):
|
class BaseEmbeddingStore(BaseModel):
|
||||||
__tablename__ = "langchain_pg_embedding"
|
__tablename__ = "langchain_pg_embedding"
|
||||||
|
|
||||||
collection_id = sqlalchemy.Column(
|
collection_id = sqlalchemy.Column(
|
||||||
@@ -79,7 +79,6 @@ class EmbeddingStore(BaseModel):
|
|||||||
)
|
)
|
||||||
collection = relationship(CollectionStore, back_populates="embeddings")
|
collection = relationship(CollectionStore, back_populates="embeddings")
|
||||||
|
|
||||||
embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT))
|
|
||||||
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||||
|
|
||||||
@@ -92,32 +91,8 @@ class QueryResult:
|
|||||||
distance: float
|
distance: float
|
||||||
|
|
||||||
|
|
||||||
class DistanceStrategy(str, enum.Enum):
|
|
||||||
EUCLIDEAN = EmbeddingStore.embedding.l2_distance
|
|
||||||
COSINE = EmbeddingStore.embedding.cosine_distance
|
|
||||||
MAX_INNER_PRODUCT = EmbeddingStore.embedding.max_inner_product
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
|
|
||||||
|
|
||||||
|
|
||||||
class PGVector(VectorStore):
|
class PGVector(VectorStore):
|
||||||
"""
|
"""VectorStore implementation using Postgres and pgvector."""
|
||||||
VectorStore implementation using Postgres and pgvector.
|
|
||||||
- `connection_string` is a postgres connection string.
|
|
||||||
- `embedding_function` any embedding function implementing
|
|
||||||
`langchain.embeddings.base.Embeddings` interface.
|
|
||||||
- `collection_name` is the name of the collection to use. (default: langchain)
|
|
||||||
- NOTE: This is not the name of the table, but the name of the collection.
|
|
||||||
The tables will be created when initializing the store (if not exists)
|
|
||||||
So, make sure the user has the right permissions to create tables.
|
|
||||||
- `distance_strategy` is the distance strategy to use. (default: EUCLIDEAN)
|
|
||||||
- `EUCLIDEAN` is the euclidean distance.
|
|
||||||
- `COSINE` is the cosine distance.
|
|
||||||
- `pre_delete_collection` if True, will delete the collection if it exists.
|
|
||||||
(default: False)
|
|
||||||
- Useful for testing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -128,7 +103,33 @@ class PGVector(VectorStore):
|
|||||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
logger: Optional[logging.Logger] = None,
|
logger: Optional[logging.Logger] = None,
|
||||||
|
vector_size: Optional[int] = PGVECTOR_VECTOR_SIZE,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize vector store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
connection_string: is a postgres connection string.
|
||||||
|
embedding_function: text embedding model.
|
||||||
|
collection_name: is the name of the collection to use. (default: langchain)
|
||||||
|
NOTE: This is not the name of the table, but the name of the collection.
|
||||||
|
The tables will be created when initializing the store
|
||||||
|
(if not exists). So, make sure the user has the right permissions
|
||||||
|
to create tables.
|
||||||
|
distance_strategy: is the distance strategy to use. (default: "cosine")
|
||||||
|
- `cosine` is the cosine distance.
|
||||||
|
- `euclidean` is the euclidean distance.
|
||||||
|
- `max_inner_product` is the max inner product.
|
||||||
|
pre_delete_collection: if True, will delete the collection if it exists.
|
||||||
|
Useful for testing.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import pgvector, please install it with `pip install "
|
||||||
|
"pgvector`."
|
||||||
|
) from e
|
||||||
self.connection_string = connection_string
|
self.connection_string = connection_string
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
@@ -136,6 +137,16 @@ class PGVector(VectorStore):
|
|||||||
self.distance_strategy = distance_strategy
|
self.distance_strategy = distance_strategy
|
||||||
self.pre_delete_collection = pre_delete_collection
|
self.pre_delete_collection = pre_delete_collection
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
|
vector_size = vector_size or len(embedding_function.embed_documents(["foo"])[0])
|
||||||
|
|
||||||
|
class EmbeddingStore(BaseEmbeddingStore):
|
||||||
|
@declared_attr
|
||||||
|
def embedding(cls):
|
||||||
|
return BaseEmbeddingStore.__table__.c.get(
|
||||||
|
"embedding", sqlalchemy.Column(Vector(vector_size))
|
||||||
|
)
|
||||||
|
|
||||||
|
self.EmbeddingStore = EmbeddingStore
|
||||||
self.__post_init__()
|
self.__post_init__()
|
||||||
|
|
||||||
def __post_init__(
|
def __post_init__(
|
||||||
@@ -145,10 +156,17 @@ class PGVector(VectorStore):
|
|||||||
Initialize the store.
|
Initialize the store.
|
||||||
"""
|
"""
|
||||||
self._conn = self.connect()
|
self._conn = self.connect()
|
||||||
# self.create_vector_extension()
|
|
||||||
self.create_tables_if_not_exists()
|
self.create_tables_if_not_exists()
|
||||||
self.create_collection()
|
self.create_collection()
|
||||||
|
|
||||||
|
def _distance_fn(self) -> Callable:
|
||||||
|
_map = {
|
||||||
|
"euclidean": self.EmbeddingStore.embedding.l2_distance,
|
||||||
|
"cosine": self.EmbeddingStore.embedding.cosine_distance,
|
||||||
|
"max_inner_product": self.EmbeddingStore.embedding.max_inner_product,
|
||||||
|
}
|
||||||
|
return _map[self.distance_strategy]
|
||||||
|
|
||||||
def connect(self) -> sqlalchemy.engine.Connection:
|
def connect(self) -> sqlalchemy.engine.Connection:
|
||||||
engine = sqlalchemy.create_engine(self.connection_string)
|
engine = sqlalchemy.create_engine(self.connection_string)
|
||||||
conn = engine.connect()
|
conn = engine.connect()
|
||||||
@@ -201,7 +219,7 @@ class PGVector(VectorStore):
|
|||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
distance_strategy: DistanceStrategy = "cosine",
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> PGVector:
|
) -> PGVector:
|
||||||
@@ -248,7 +266,7 @@ class PGVector(VectorStore):
|
|||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError("Collection not found")
|
raise ValueError("Collection not found")
|
||||||
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
||||||
embedding_store = EmbeddingStore(
|
embedding_store = self.EmbeddingStore(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
document=text,
|
document=text,
|
||||||
cmetadata=metadata,
|
cmetadata=metadata,
|
||||||
@@ -288,7 +306,7 @@ class PGVector(VectorStore):
|
|||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError("Collection not found")
|
raise ValueError("Collection not found")
|
||||||
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
||||||
embedding_store = EmbeddingStore(
|
embedding_store = self.EmbeddingStore(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
document=text,
|
document=text,
|
||||||
cmetadata=metadata,
|
cmetadata=metadata,
|
||||||
@@ -357,7 +375,7 @@ class PGVector(VectorStore):
|
|||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError("Collection not found")
|
raise ValueError("Collection not found")
|
||||||
|
|
||||||
filter_by = EmbeddingStore.collection_id == collection.uuid
|
filter_by = self.EmbeddingStore.collection_id == collection.uuid
|
||||||
|
|
||||||
if filter is not None:
|
if filter is not None:
|
||||||
filter_clauses = []
|
filter_clauses = []
|
||||||
@@ -367,12 +385,12 @@ class PGVector(VectorStore):
|
|||||||
value_case_insensitive = {
|
value_case_insensitive = {
|
||||||
k.lower(): v for k, v in value.items()
|
k.lower(): v for k, v in value.items()
|
||||||
}
|
}
|
||||||
filter_by_metadata = EmbeddingStore.cmetadata[key].astext.in_(
|
filter_by_metadata = self.EmbeddingStore.cmetadata[
|
||||||
value_case_insensitive[IN]
|
key
|
||||||
)
|
].astext.in_(value_case_insensitive[IN])
|
||||||
filter_clauses.append(filter_by_metadata)
|
filter_clauses.append(filter_by_metadata)
|
||||||
else:
|
else:
|
||||||
filter_by_metadata = EmbeddingStore.cmetadata[
|
filter_by_metadata = self.EmbeddingStore.cmetadata[
|
||||||
key
|
key
|
||||||
].astext == str(value)
|
].astext == str(value)
|
||||||
filter_clauses.append(filter_by_metadata)
|
filter_clauses.append(filter_by_metadata)
|
||||||
@@ -381,14 +399,14 @@ class PGVector(VectorStore):
|
|||||||
|
|
||||||
results: List[QueryResult] = (
|
results: List[QueryResult] = (
|
||||||
session.query(
|
session.query(
|
||||||
EmbeddingStore,
|
self.EmbeddingStore,
|
||||||
self.distance_strategy(embedding).label("distance"), # type: ignore
|
self._distance_fn(embedding).label("distance"), # type: ignore
|
||||||
)
|
)
|
||||||
.filter(filter_by)
|
.filter(filter_by)
|
||||||
.order_by(sqlalchemy.asc("distance"))
|
.order_by(sqlalchemy.asc("distance"))
|
||||||
.join(
|
.join(
|
||||||
CollectionStore,
|
CollectionStore,
|
||||||
EmbeddingStore.collection_id == CollectionStore.uuid,
|
self.EmbeddingStore.collection_id == CollectionStore.uuid,
|
||||||
)
|
)
|
||||||
.limit(k)
|
.limit(k)
|
||||||
.all()
|
.all()
|
||||||
@@ -435,7 +453,7 @@ class PGVector(VectorStore):
|
|||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
distance_strategy: DistanceStrategy = "cosine",
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@@ -467,7 +485,7 @@ class PGVector(VectorStore):
|
|||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
distance_strategy: DistanceStrategy = "cosine",
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@@ -528,15 +546,14 @@ class PGVector(VectorStore):
|
|||||||
documents: List[Document],
|
documents: List[Document],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
distance_strategy: DistanceStrategy = "cosine",
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> PGVector:
|
) -> PGVector:
|
||||||
"""
|
"""
|
||||||
Return VectorStore initialized from documents and embeddings.
|
Return VectorStore initialized from documents and embeddings.
|
||||||
Postgres connection string is required
|
Postgres connection string is required. Either pass it as a parameter
|
||||||
"Either pass it as a parameter
|
|
||||||
or set the PGVECTOR_CONNECTION_STRING environment variable.
|
or set the PGVECTOR_CONNECTION_STRING environment variable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -547,9 +564,9 @@ class PGVector(VectorStore):
|
|||||||
kwargs["connection_string"] = connection_string
|
kwargs["connection_string"] = connection_string
|
||||||
|
|
||||||
return cls.from_texts(
|
return cls.from_texts(
|
||||||
texts=texts,
|
texts,
|
||||||
|
embedding,
|
||||||
pre_delete_collection=pre_delete_collection,
|
pre_delete_collection=pre_delete_collection,
|
||||||
embedding=embedding,
|
|
||||||
distance_strategy=distance_strategy,
|
distance_strategy=distance_strategy,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
ids=ids,
|
ids=ids,
|
||||||
|
|||||||
24
tests/unit_tests/vectorstores/test_pgvector.py
Normal file
24
tests/unit_tests/vectorstores/test_pgvector.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.vectorstores.pgvector import PGVECTOR_VECTOR_SIZE, PGVector
|
||||||
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("pgvector")
|
||||||
|
def test_embedding_store_init_defaults() -> None:
|
||||||
|
expected = PGVECTOR_VECTOR_SIZE
|
||||||
|
actual = PGVector(
|
||||||
|
"postgresql+psycopg2://admin:admin@localhost:5432/mydatabase", FakeEmbeddings()
|
||||||
|
).EmbeddingStore.embedding.type.dim
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("pgvector")
|
||||||
|
def test_embedding_store_init_vector_size() -> None:
|
||||||
|
expected = 2
|
||||||
|
actual = PGVector(
|
||||||
|
"postgresql+psycopg2://admin:admin@localhost:5432/mydatabase",
|
||||||
|
FakeEmbeddings(),
|
||||||
|
vector_size=2,
|
||||||
|
).EmbeddingStore.embedding.type.dim
|
||||||
|
assert expected == actual
|
||||||
Reference in New Issue
Block a user