diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index b29e5b3d4fc..81951f490a0 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -296,8 +296,18 @@ class PGVector(VectorStore): if filter is not None: filter_clauses = [] for key, value in filter.items(): - filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value) - filter_clauses.append(filter_by_metadata) + IN = "in" + if isinstance(value, dict) and IN in map(str.lower, value): + value_case_insensitive = {k.lower(): v for k, v in value.items()} + filter_by_metadata = EmbeddingStore.cmetadata[key].astext.in_( + value_case_insensitive[IN] + ) + filter_clauses.append(filter_by_metadata) + else: + filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str( + value + ) + filter_clauses.append(filter_by_metadata) filter_by = sqlalchemy.and_(filter_by, *filter_clauses) diff --git a/tests/integration_tests/vectorstores/test_pgvector.py b/tests/integration_tests/vectorstores/test_pgvector.py index 3560bb591ef..8ad7f1bc147 100644 --- a/tests/integration_tests/vectorstores/test_pgvector.py +++ b/tests/integration_tests/vectorstores/test_pgvector.py @@ -147,3 +147,24 @@ def test_pgvector_collection_with_metadata() -> None: else: assert collection.name == "test_collection" assert collection.cmetadata == {"foo": "bar"} + + +def test_pgvector_with_filter_in_set() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score( + "foo", k=2, filter={"page": {"IN": ["0", "2"]}} + ) + assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 0.0), + (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406), + ]