mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
Add PGVector collection metadata (#1887)
The `CollectionStore` for `PGVector` has a `cmetadata` field but it's never used. This PR adds the ability to save metadata information to the collection.
This commit is contained in:
parent
d08f940336
commit
2212520a6c
@ -5,7 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import JSON, UUID
|
||||||
from sqlalchemy.orm import Mapped, Session, declarative_base, relationship
|
from sqlalchemy.orm import Mapped, Session, declarative_base, relationship
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
@ -29,7 +29,7 @@ class CollectionStore(BaseModel):
|
|||||||
__tablename__ = "langchain_pg_collection"
|
__tablename__ = "langchain_pg_collection"
|
||||||
|
|
||||||
name = sqlalchemy.Column(sqlalchemy.String)
|
name = sqlalchemy.Column(sqlalchemy.String)
|
||||||
cmetadata = sqlalchemy.Column(sqlalchemy.JSON)
|
cmetadata = sqlalchemy.Column(JSON)
|
||||||
|
|
||||||
embeddings = relationship(
|
embeddings = relationship(
|
||||||
"EmbeddingStore",
|
"EmbeddingStore",
|
||||||
@ -57,7 +57,7 @@ class CollectionStore(BaseModel):
|
|||||||
if collection:
|
if collection:
|
||||||
return collection, created
|
return collection, created
|
||||||
|
|
||||||
collection = cls(name=name, metadata=cmetadata)
|
collection = cls(name=name, cmetadata=cmetadata)
|
||||||
session.add(collection)
|
session.add(collection)
|
||||||
session.commit()
|
session.commit()
|
||||||
created = True
|
created = True
|
||||||
@ -121,6 +121,7 @@ class PGVector(VectorStore):
|
|||||||
connection_string: str,
|
connection_string: str,
|
||||||
embedding_function: Embeddings,
|
embedding_function: Embeddings,
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
|
collection_metadata: Optional[dict] = None,
|
||||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
logger: Optional[logging.Logger] = None,
|
logger: Optional[logging.Logger] = None,
|
||||||
@ -128,6 +129,7 @@ class PGVector(VectorStore):
|
|||||||
self.connection_string = connection_string
|
self.connection_string = connection_string
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
|
self.collection_metadata = collection_metadata
|
||||||
self.distance_strategy = distance_strategy
|
self.distance_strategy = distance_strategy
|
||||||
self.pre_delete_collection = pre_delete_collection
|
self.pre_delete_collection = pre_delete_collection
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
@ -168,7 +170,9 @@ class PGVector(VectorStore):
|
|||||||
if self.pre_delete_collection:
|
if self.pre_delete_collection:
|
||||||
self.delete_collection()
|
self.delete_collection()
|
||||||
with Session(self._conn) as session:
|
with Session(self._conn) as session:
|
||||||
CollectionStore.get_or_create(session, self.collection_name)
|
CollectionStore.get_or_create(
|
||||||
|
session, self.collection_name, cmetadata=self.collection_metadata
|
||||||
|
)
|
||||||
|
|
||||||
def delete_collection(self) -> None:
|
def delete_collection(self) -> None:
|
||||||
self.logger.debug("Trying to delete collection")
|
self.logger.debug("Trying to delete collection")
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.vectorstores.pgvector import PGVector
|
from langchain.vectorstores.pgvector import PGVector
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
@ -79,3 +81,21 @@ def test_pgvector_with_metadatas_with_scores() -> None:
|
|||||||
)
|
)
|
||||||
output = docsearch.similarity_search_with_score("foo", k=1)
|
output = docsearch.similarity_search_with_score("foo", k=1)
|
||||||
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
|
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pgvector_collection_with_metadata() -> None:
|
||||||
|
"""Test end to end collection construction"""
|
||||||
|
pgvector = PGVector(
|
||||||
|
collection_name="test_collection",
|
||||||
|
collection_metadata={"foo": "bar"},
|
||||||
|
embedding_function=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
session = Session(pgvector.connect())
|
||||||
|
collection = pgvector.get_collection(session)
|
||||||
|
if collection is None:
|
||||||
|
assert False, "Expected a CollectionStore object but received None"
|
||||||
|
else:
|
||||||
|
assert collection.name == "test_collection"
|
||||||
|
assert collection.cmetadata == {"foo": "bar"}
|
||||||
|
Loading…
Reference in New Issue
Block a user