add filter to sklearn vector store functions (#8113)

# What
- Add a `filter` option to the `SKLearnVectorStore` search functions (`similarity_search`, `similarity_search_with_score`, and the max-marginal-relevance searches).

- Issue: None
- Dependencies: None
- Tag maintainer: @rlancemartin, @eyurtsev
- Twitter handle: @MlopsJ

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
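A minimal usage sketch of the new parameter. `FakeEmbeddings` from `langchain.embeddings` is only a stand-in so the snippet runs offline; any `Embeddings` implementation works, and the `size` value is arbitrary:

```python
from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import SKLearnVectorStore

# Build a small store where each text carries a "page" metadata entry.
texts = ["foo", "foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
store = SKLearnVectorStore.from_texts(texts, FakeEmbeddings(size=16), metadatas=metadatas)

# fetch_k candidates are retrieved first, the filter is applied to them,
# and at most k of the surviving documents are returned.
docs = store.similarity_search("foo", k=1, fetch_k=4, filter={"page": "1"})
print(docs[0].metadata["page"])  # -> "1"
```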

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
shibuiwilliam authored on 2023-08-04 15:06:41 +09:00 (committed by GitHub)
parent 2759e2d857
commit 0f0ccfe7f6
3 changed files with 131 additions and 26 deletions

File 1 of 3: SKLearn vector store example notebook
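The notebook gains a "Filter" section whose new cell filters on `"id"`. That works because `SKLearnVectorStore` copies each document's id into its metadata next to the user-supplied fields (visible in the implementation diff further down). A short sketch of obtaining an id to filter on, assuming the notebook's `vector_store` object and that `add_texts` returns the generated ids as in the base `VectorStore` API:

```python
# The ids returned by add_texts double as filter values, because the store
# exposes each document's id under the "id" metadata key.
ids = vector_store.add_texts(["a brand new document"])
docs = vector_store.similarity_search("brand new document", filter={"id": ids[0]})
print(len(docs))  # -> 1
```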

```diff
@@ -13,7 +13,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 1,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -56,7 +56,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -65,7 +65,7 @@
   "from langchain.vectorstores import SKLearnVectorStore\n",
   "from langchain.document_loaders import TextLoader\n",
   "\n",
-  "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
+  "loader = TextLoader(\"../../../extras/modules/state_of_the_union.txt\")\n",
   "documents = loader.load()\n",
   "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
   "docs = text_splitter.split_documents(documents)\n",
@@ -81,7 +81,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 4,
+  "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
@@ -100,6 +100,7 @@
  ],
  "source": [
   "import tempfile\n",
+  "import os\n",
   "\n",
   "persist_path = os.path.join(tempfile.gettempdir(), \"union.parquet\")\n",
   "\n",
@@ -184,6 +185,32 @@
   "print(docs[0].page_content)"
  ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## Filter"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 8,
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "1\n"
+    ]
+   }
+  ],
+  "source": [
+   "_filter = {\"id\": \"c53e6eac-0070-403c-8435-a9e528539610\"}\n",
+   "docs = vector_store.similarity_search(query, filter=_filter)\n",
+   "print(len(docs))"
+  ]
+ },
  {
   "cell_type": "markdown",
   "metadata": {},
@@ -217,7 +244,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.6"
+  "version": "3.10.1"
  }
 },
 "nbformat": 4,
```

File 2 of 3: SKLearnVectorStore implementation

```diff
@@ -233,33 +233,66 @@ class SKLearnVectorStore(VectorStore):
         return list(zip(neigh_idxs[0], neigh_dists[0]))
 
     def similarity_search_with_score(
-        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
+        self,
+        query: str,
+        *,
+        k: int = DEFAULT_K,
+        fetch_k: int = DEFAULT_FETCH_K,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         query_embedding = self._embedding_function.embed_query(query)
         indices_dists = self._similarity_index_search_with_score(
-            query_embedding, k=k, **kwargs
+            query_embedding, k=fetch_k, **kwargs
         )
-        return [
-            (
-                Document(
-                    page_content=self._texts[idx],
-                    metadata={"id": self._ids[idx], **self._metadatas[idx]},
-                ),
-                dist,
-            )
-            for idx, dist in indices_dists
-        ]
+
+        docs: List[Tuple[Document, float]] = []
+        for idx, dist in indices_dists:
+            doc = (
+                Document(
+                    page_content=self._texts[idx],
+                    metadata={"id": self._ids[idx], **self._metadatas[idx]},
+                ),
+                dist,
+            )
+            if filter is None:
+                docs.append(doc)
+            else:
+                filter = {
+                    key: [value] if not isinstance(value, list) else value
+                    for key, value in filter.items()
+                }
+                if all(
+                    doc[0].metadata.get(key) in value for key, value in filter.items()
+                ):
+                    docs.append(doc)
+        return docs[:k]
 
     def similarity_search(
-        self, query: str, k: int = DEFAULT_K, **kwargs: Any
+        self,
+        query: str,
+        k: int = DEFAULT_K,
+        fetch_k: int = DEFAULT_FETCH_K,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> List[Document]:
-        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
+        docs_scores = self.similarity_search_with_score(
+            query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
+        )
         return [doc for doc, _ in docs_scores]
 
     def _similarity_search_with_relevance_scores(
-        self, query: str, k: int = DEFAULT_K, **kwargs: Any
+        self,
+        query: str,
+        k: int = DEFAULT_K,
+        fetch_k: int = DEFAULT_FETCH_K,
+        filter: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
-        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
+        docs_dists = self.similarity_search_with_score(
+            query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
+        )
         docs, dists = zip(*docs_dists)
         scores = [1 / math.exp(dist) for dist in dists]
         return list(zip(list(docs), scores))
```
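The post-filter added above normalizes scalar filter values into one-element lists, then keeps a document only when every filter key's metadata value is a member of the corresponding list. A standalone restatement of that predicate for illustration (the helper name `matches_filter` is mine, not part of the PR):

```python
from typing import Any, Dict


def matches_filter(metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
    """Mirror the PR's post-filter: scalars are coerced to one-element
    lists, and every key must match for a document to be kept."""
    normalized = {
        key: [value] if not isinstance(value, list) else value
        for key, value in filter.items()
    }
    return all(metadata.get(key) in value for key, value in normalized.items())


assert matches_filter({"page": "1"}, {"page": "1"})  # scalar match
assert matches_filter({"page": "1"}, {"page": ["0", "1"]})  # a list acts as an OR
assert not matches_filter({"page": "2"}, {"page": "1"})  # every key must match
```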
```diff
@@ -270,6 +303,7 @@ class SKLearnVectorStore(VectorStore):
         k: int = DEFAULT_K,
         fetch_k: int = DEFAULT_FETCH_K,
         lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance.
@@ -283,6 +317,7 @@ class SKLearnVectorStore(VectorStore):
                 of diversity among the results with 0 corresponding
                 to maximum diversity and 1 to minimum diversity.
                 Defaults to 0.5.
+            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
@@ -294,17 +329,28 @@ class SKLearnVectorStore(VectorStore):
         mmr_selected = maximal_marginal_relevance(
             self._np.array(embedding, dtype=self._np.float32),
             result_embeddings,
-            k=k,
+            k=fetch_k,
             lambda_mult=lambda_mult,
         )
         mmr_indices = [indices[i] for i in mmr_selected]
-        return [
-            Document(
-                page_content=self._texts[idx],
-                metadata={"id": self._ids[idx], **self._metadatas[idx]},
-            )
-            for idx in mmr_indices
-        ]
+
+        docs = []
+        for idx in mmr_indices:
+            doc = Document(
+                page_content=self._texts[idx],
+                metadata={"id": self._ids[idx], **self._metadatas[idx]},
+            )
+            if filter is None:
+                docs.append(doc)
+            else:
+                filter = {
+                    key: [value] if not isinstance(value, list) else value
+                    for key, value in filter.items()
+                }
+                if all(doc.metadata.get(key) in value for key, value in filter.items()):
+                    docs.append(doc)
+        return docs[:k]
@@ -312,6 +358,7 @@ class SKLearnVectorStore(VectorStore):
         k: int = DEFAULT_K,
         fetch_k: int = DEFAULT_FETCH_K,
         lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance.
@@ -325,6 +372,7 @@ class SKLearnVectorStore(VectorStore):
                 of diversity among the results with 0 corresponding
                 to maximum diversity and 1 to minimum diversity.
                 Defaults to 0.5.
+            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
@@ -335,7 +383,7 @@ class SKLearnVectorStore(VectorStore):
         embedding = self._embedding_function.embed_query(query)
         docs = self.max_marginal_relevance_search_by_vector(
-            embedding, k, fetch_k, lambda_mul=lambda_mult
+            embedding, k, fetch_k, lambda_mul=lambda_mult, filter=filter, **kwargs
         )
         return docs
```
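A consequence of the design above: filtering happens after retrieval, so only the `fetch_k` nearest candidates (or the `fetch_k` MMR selections) are ever tested against the filter. A selective filter can therefore return fewer than `k` documents unless `fetch_k` is raised. An illustration, reusing the `store` from the sketch at the top:

```python
# Only two candidates are fetched, so the document with page == "3" may
# never reach the filter stage and the result can come back empty.
few = store.similarity_search("foo", k=1, fetch_k=2, filter={"page": "3"})

# Fetching all four documents guarantees the filter sees the matching one.
all_seen = store.similarity_search("foo", k=1, fetch_k=4, filter={"page": "3"})
print(len(few), len(all_seen))  # possibly "0 1"
```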

File 3 of 3: SKLearn vector store tests

```diff
@@ -12,7 +12,7 @@ def test_sklearn() -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
     docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
-    output = docsearch.similarity_search("foo", k=1)
+    output = docsearch.similarity_search("foo", k=1, fetch_k=3)
     assert len(output) == 1
     assert output[0].page_content == "foo"
@@ -27,10 +27,24 @@ def test_sklearn_with_metadatas() -> None:
         FakeEmbeddings(),
         metadatas=metadatas,
     )
-    output = docsearch.similarity_search("foo", k=1)
+    output = docsearch.similarity_search("foo", k=1, fetch_k=3)
     assert output[0].metadata["page"] == "0"
 
 
+@pytest.mark.requires("numpy", "sklearn")
+def test_sklearn_with_metadatas_and_filter() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = SKLearnVectorStore.from_texts(
+        texts,
+        FakeEmbeddings(),
+        metadatas=metadatas,
+    )
+    output = docsearch.similarity_search("foo", k=1, fetch_k=4, filter={"page": "1"})
+    assert output[0].metadata["page"] == "1"
+
+
 @pytest.mark.requires("numpy", "sklearn")
 def test_sklearn_with_metadatas_with_scores() -> None:
     """Test end to end construction and scored search."""
@@ -41,7 +55,7 @@ def test_sklearn_with_metadatas_with_scores() -> None:
         FakeEmbeddings(),
         metadatas=metadatas,
     )
-    output = docsearch.similarity_search_with_relevance_scores("foo", k=1)
+    output = docsearch.similarity_search_with_relevance_scores("foo", k=1, fetch_k=3)
     assert len(output) == 1
     doc, score = output[0]
     assert doc.page_content == "foo"
@@ -61,7 +75,7 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None:
         serializer="json",
     )
-    output = docsearch.similarity_search("foo", k=1)
+    output = docsearch.similarity_search("foo", k=1, fetch_k=3)
     assert len(output) == 1
     assert output[0].page_content == "foo"
@@ -71,7 +85,7 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None:
     docsearch = SKLearnVectorStore(
         FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
     )
-    output = docsearch.similarity_search("foo", k=1)
+    output = docsearch.similarity_search("foo", k=1, fetch_k=3)
     assert len(output) == 1
     assert output[0].page_content == "foo"
@@ -98,3 +112,19 @@ def test_sklearn_mmr_by_vector() -> None:
     )
     assert len(output) == 1
     assert output[0].page_content == "foo"
+
+
+@pytest.mark.requires("numpy", "sklearn")
+def test_sklearn_mmr_with_metadata_and_filter() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = SKLearnVectorStore.from_texts(
+        texts, FakeEmbeddings(), metadatas=metadatas
+    )
+    output = docsearch.max_marginal_relevance_search(
+        "foo", k=1, fetch_k=4, filter={"page": "1"}
+    )
+    assert len(output) == 1
+    assert output[0].page_content == "foo"
+    assert output[0].metadata["page"] == "1"
```
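The new tests only exercise scalar filter values. Because the implementation coerces scalars into lists, a list-valued filter should behave as an OR over its values; a hypothetical extra test (not part of this PR) that would pin that behavior down:

```python
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_filter_with_list_values() -> None:
    """Hypothetical: a list filter matches any of its values."""
    texts = ["foo", "foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = SKLearnVectorStore.from_texts(
        texts, FakeEmbeddings(), metadatas=metadatas
    )
    output = docsearch.similarity_search(
        "foo", k=2, fetch_k=4, filter={"page": ["0", "1"]}
    )
    assert all(doc.metadata["page"] in {"0", "1"} for doc in output)
```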