mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-16 01:59:52 +00:00
Compare commits
5 Commits
langchain-
...
dev2049/pg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2236d1deec | ||
|
|
ce5ee49bd2 | ||
|
|
a40a59a83c | ||
|
|
0d3f8a3c61 | ||
|
|
97abe337ba |
@@ -24,6 +24,10 @@ To import this vectorstore:
|
||||
from langchain.vectorstores.pgvector import PGVector
|
||||
```
|
||||
|
||||
PGVector embedding size is not autodetected. If you are using OpenAI embeddings or any other embedding model with 1536 dimensions, the
|
||||
default is fine. If you are going to use, for example, HuggingFaceEmbeddings, you need to set the environment variable named `PGVECTOR_VECTOR_SIZE`
|
||||
to the needed value. In the case of HuggingFaceEmbeddings, it would be: `PGVECTOR_VECTOR_SIZE=768`
|
||||
|
||||
### Usage
|
||||
|
||||
For a more detailed walkthrough of the PGVector Wrapper, see [this notebook](../modules/indexes/vectorstores/examples/pgvector.ipynb)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
"""VectorStore wrapper around a Postgres/PGVector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import sqlalchemy
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.dialects.postgresql import JSON, UUID
|
||||
from sqlalchemy.orm import Session, declarative_base, relationship
|
||||
from sqlalchemy.orm import Session, declarative_base, declared_attr, relationship
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@@ -19,8 +17,10 @@ from langchain.vectorstores.base import VectorStore
|
||||
Base = declarative_base() # type: Any
|
||||
|
||||
|
||||
ADA_TOKEN_COUNT = 1536
|
||||
PGVECTOR_VECTOR_SIZE = 1536
|
||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||
DEFAULT_DISTANCE_STRATEGY = "cosine"
|
||||
DistanceStrategy = Literal["cosine", "euclidean", "max_inner_product"]
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
@@ -67,7 +67,7 @@ class CollectionStore(BaseModel):
|
||||
return collection, created
|
||||
|
||||
|
||||
class EmbeddingStore(BaseModel):
|
||||
class BaseEmbeddingStore(BaseModel):
|
||||
__tablename__ = "langchain_pg_embedding"
|
||||
|
||||
collection_id = sqlalchemy.Column(
|
||||
@@ -79,7 +79,6 @@ class EmbeddingStore(BaseModel):
|
||||
)
|
||||
collection = relationship(CollectionStore, back_populates="embeddings")
|
||||
|
||||
embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT))
|
||||
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||
|
||||
@@ -92,32 +91,8 @@ class QueryResult:
|
||||
distance: float
|
||||
|
||||
|
||||
class DistanceStrategy(str, enum.Enum):
|
||||
EUCLIDEAN = EmbeddingStore.embedding.l2_distance
|
||||
COSINE = EmbeddingStore.embedding.cosine_distance
|
||||
MAX_INNER_PRODUCT = EmbeddingStore.embedding.max_inner_product
|
||||
|
||||
|
||||
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
|
||||
|
||||
|
||||
class PGVector(VectorStore):
|
||||
"""
|
||||
VectorStore implementation using Postgres and pgvector.
|
||||
- `connection_string` is a postgres connection string.
|
||||
- `embedding_function` any embedding function implementing
|
||||
`langchain.embeddings.base.Embeddings` interface.
|
||||
- `collection_name` is the name of the collection to use. (default: langchain)
|
||||
- NOTE: This is not the name of the table, but the name of the collection.
|
||||
The tables will be created when initializing the store (if not exists)
|
||||
So, make sure the user has the right permissions to create tables.
|
||||
- `distance_strategy` is the distance strategy to use. (default: EUCLIDEAN)
|
||||
- `EUCLIDEAN` is the euclidean distance.
|
||||
- `COSINE` is the cosine distance.
|
||||
- `pre_delete_collection` if True, will delete the collection if it exists.
|
||||
(default: False)
|
||||
- Useful for testing.
|
||||
"""
|
||||
"""VectorStore implementation using Postgres and pgvector."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -128,7 +103,33 @@ class PGVector(VectorStore):
|
||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||
pre_delete_collection: bool = False,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
vector_size: Optional[int] = PGVECTOR_VECTOR_SIZE,
|
||||
) -> None:
|
||||
"""Initialize vector store.
|
||||
|
||||
Args:
|
||||
|
||||
connection_string: is a postgres connection string.
|
||||
embedding_function: text embedding model.
|
||||
collection_name: is the name of the collection to use. (default: langchain)
|
||||
NOTE: This is not the name of the table, but the name of the collection.
|
||||
The tables will be created when initializing the store
|
||||
(if not exists). So, make sure the user has the right permissions
|
||||
to create tables.
|
||||
distance_strategy: is the distance strategy to use. (default: "cosine")
|
||||
- `cosine` is the cosine distance.
|
||||
- `euclidean` is the euclidean distance.
|
||||
- `max_inner_product` is the max inner product.
|
||||
pre_delete_collection: if True, will delete the collection if it exists.
|
||||
Useful for testing.
|
||||
"""
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import pgvector, please install it with `pip install "
|
||||
"pgvector`."
|
||||
) from e
|
||||
self.connection_string = connection_string
|
||||
self.embedding_function = embedding_function
|
||||
self.collection_name = collection_name
|
||||
@@ -136,6 +137,16 @@ class PGVector(VectorStore):
|
||||
self.distance_strategy = distance_strategy
|
||||
self.pre_delete_collection = pre_delete_collection
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
vector_size = vector_size or len(embedding_function.embed_documents(["foo"])[0])
|
||||
|
||||
class EmbeddingStore(BaseEmbeddingStore):
|
||||
@declared_attr
|
||||
def embedding(cls):
|
||||
return BaseEmbeddingStore.__table__.c.get(
|
||||
"embedding", sqlalchemy.Column(Vector(vector_size))
|
||||
)
|
||||
|
||||
self.EmbeddingStore = EmbeddingStore
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(
|
||||
@@ -145,10 +156,17 @@ class PGVector(VectorStore):
|
||||
Initialize the store.
|
||||
"""
|
||||
self._conn = self.connect()
|
||||
# self.create_vector_extension()
|
||||
self.create_tables_if_not_exists()
|
||||
self.create_collection()
|
||||
|
||||
def _distance_fn(self) -> Callable:
|
||||
_map = {
|
||||
"euclidean": self.EmbeddingStore.embedding.l2_distance,
|
||||
"cosine": self.EmbeddingStore.embedding.cosine_distance,
|
||||
"max_inner_product": self.EmbeddingStore.embedding.max_inner_product,
|
||||
}
|
||||
return _map[self.distance_strategy]
|
||||
|
||||
def connect(self) -> sqlalchemy.engine.Connection:
|
||||
engine = sqlalchemy.create_engine(self.connection_string)
|
||||
conn = engine.connect()
|
||||
@@ -201,7 +219,7 @@ class PGVector(VectorStore):
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
||||
distance_strategy: DistanceStrategy = "cosine",
|
||||
pre_delete_collection: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> PGVector:
|
||||
@@ -248,7 +266,7 @@ class PGVector(VectorStore):
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
||||
embedding_store = EmbeddingStore(
|
||||
embedding_store = self.EmbeddingStore(
|
||||
embedding=embedding,
|
||||
document=text,
|
||||
cmetadata=metadata,
|
||||
@@ -288,7 +306,7 @@ class PGVector(VectorStore):
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
||||
embedding_store = EmbeddingStore(
|
||||
embedding_store = self.EmbeddingStore(
|
||||
embedding=embedding,
|
||||
document=text,
|
||||
cmetadata=metadata,
|
||||
@@ -357,7 +375,7 @@ class PGVector(VectorStore):
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
|
||||
filter_by = EmbeddingStore.collection_id == collection.uuid
|
||||
filter_by = self.EmbeddingStore.collection_id == collection.uuid
|
||||
|
||||
if filter is not None:
|
||||
filter_clauses = []
|
||||
@@ -367,12 +385,12 @@ class PGVector(VectorStore):
|
||||
value_case_insensitive = {
|
||||
k.lower(): v for k, v in value.items()
|
||||
}
|
||||
filter_by_metadata = EmbeddingStore.cmetadata[key].astext.in_(
|
||||
value_case_insensitive[IN]
|
||||
)
|
||||
filter_by_metadata = self.EmbeddingStore.cmetadata[
|
||||
key
|
||||
].astext.in_(value_case_insensitive[IN])
|
||||
filter_clauses.append(filter_by_metadata)
|
||||
else:
|
||||
filter_by_metadata = EmbeddingStore.cmetadata[
|
||||
filter_by_metadata = self.EmbeddingStore.cmetadata[
|
||||
key
|
||||
].astext == str(value)
|
||||
filter_clauses.append(filter_by_metadata)
|
||||
@@ -381,14 +399,14 @@ class PGVector(VectorStore):
|
||||
|
||||
results: List[QueryResult] = (
|
||||
session.query(
|
||||
EmbeddingStore,
|
||||
self.distance_strategy(embedding).label("distance"), # type: ignore
|
||||
self.EmbeddingStore,
|
||||
self._distance_fn(embedding).label("distance"), # type: ignore
|
||||
)
|
||||
.filter(filter_by)
|
||||
.order_by(sqlalchemy.asc("distance"))
|
||||
.join(
|
||||
CollectionStore,
|
||||
EmbeddingStore.collection_id == CollectionStore.uuid,
|
||||
self.EmbeddingStore.collection_id == CollectionStore.uuid,
|
||||
)
|
||||
.limit(k)
|
||||
.all()
|
||||
@@ -435,7 +453,7 @@ class PGVector(VectorStore):
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
||||
distance_strategy: DistanceStrategy = "cosine",
|
||||
ids: Optional[List[str]] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -467,7 +485,7 @@ class PGVector(VectorStore):
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
||||
distance_strategy: DistanceStrategy = "cosine",
|
||||
ids: Optional[List[str]] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -528,15 +546,14 @@ class PGVector(VectorStore):
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||
distance_strategy: DistanceStrategy = "cosine",
|
||||
ids: Optional[List[str]] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> PGVector:
|
||||
"""
|
||||
Return VectorStore initialized from documents and embeddings.
|
||||
Postgres connection string is required
|
||||
"Either pass it as a parameter
|
||||
Postgres connection string is required. Either pass it as a parameter
|
||||
or set the PGVECTOR_CONNECTION_STRING environment variable.
|
||||
"""
|
||||
|
||||
@@ -547,9 +564,9 @@ class PGVector(VectorStore):
|
||||
kwargs["connection_string"] = connection_string
|
||||
|
||||
return cls.from_texts(
|
||||
texts=texts,
|
||||
texts,
|
||||
embedding,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
embedding=embedding,
|
||||
distance_strategy=distance_strategy,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
|
||||
24
tests/unit_tests/vectorstores/test_pgvector.py
Normal file
24
tests/unit_tests/vectorstores/test_pgvector.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
|
||||
from langchain.vectorstores.pgvector import PGVECTOR_VECTOR_SIZE, PGVector
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
|
||||
|
||||
@pytest.mark.requires("pgvector")
|
||||
def test_embedding_store_init_defaults() -> None:
|
||||
expected = PGVECTOR_VECTOR_SIZE
|
||||
actual = PGVector(
|
||||
"postgresql+psycopg2://admin:admin@localhost:5432/mydatabase", FakeEmbeddings()
|
||||
).EmbeddingStore.embedding.type.dim
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.requires("pgvector")
|
||||
def test_embedding_store_init_vector_size() -> None:
|
||||
expected = 2
|
||||
actual = PGVector(
|
||||
"postgresql+psycopg2://admin:admin@localhost:5432/mydatabase",
|
||||
FakeEmbeddings(),
|
||||
vector_size=2,
|
||||
).EmbeddingStore.embedding.type.dim
|
||||
assert expected == actual
|
||||
Reference in New Issue
Block a user