Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-27 17:08:47 +00:00)
add faiss local saving/loading (#676)
- This uses the faiss built-in `write_index` and `read_index` to save and load faiss indexes locally
- Also fixes #674
- The save/load functions also use the faiss library, so I refactored the dependency into a function
This commit is contained in:
parent e45f7e40e8
commit e04b063ff4
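For orientation before the diff: a minimal sketch of the round trip the new methods enable. `FAISS`, `save_local`, and `load_local` come from the diff below; `OpenAIEmbeddings` is borrowed from the class docstring example, and the `"my.index"` path is hypothetical.

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

texts = ["foo", "bar", "baz"]
db = FAISS.from_texts(texts, OpenAIEmbeddings())

db.save_local("my.index")   # writes only the raw faiss index to disk
db.load_local("my.index")   # restores self.index on the same wrapper
```

Note that only the faiss index itself is serialized; the docstore and the index-to-id mapping stay in memory, so loading is meant to happen on a wrapper instance that still holds them.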
@@ -14,6 +14,19 @@ from langchain.vectorstores.base import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance
 
 
+def dependable_faiss_import() -> Any:
+    """Import faiss if available, otherwise raise error."""
+    try:
+        import faiss
+    except ImportError:
+        raise ValueError(
+            "Could not import faiss python package. "
+            "Please install it with `pip install faiss` "
+            "or `pip install faiss-cpu` (depending on Python version)."
+        )
+    return faiss
+
+
 class FAISS(VectorStore):
     """Wrapper around FAISS vector database.
 
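The helper above centralizes the optional-dependency check so `from_texts` and the new persistence methods share one error path. A standalone sketch of the guarded import in use, assuming `faiss-cpu` is installed and that the helper lives in `langchain.vectorstores.faiss` (the module this hunk patches):

```python
import numpy as np

from langchain.vectorstores.faiss import dependable_faiss_import

# Raises ValueError with an install hint if faiss is absent.
faiss = dependable_faiss_import()

# Mirror what FAISS.from_texts does internally: a flat L2 index over float32 vectors.
index = faiss.IndexFlatL2(4)
index.add(np.array([[0.0, 0.1, 0.2, 0.3]], dtype=np.float32))
print(index.ntotal)  # 1
```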
@@ -174,14 +187,7 @@ class FAISS(VectorStore):
                 embeddings = OpenAIEmbeddings()
                 faiss = FAISS.from_texts(texts, embeddings)
         """
-        try:
-            import faiss
-        except ImportError:
-            raise ValueError(
-                "Could not import faiss python package. "
-                "Please it install it with `pip install faiss` "
-                "or `pip install faiss-cpu` (depending on Python version)."
-            )
+        faiss = dependable_faiss_import()
         embeddings = embedding.embed_documents(texts)
         index = faiss.IndexFlatL2(len(embeddings[0]))
         index.add(np.array(embeddings, dtype=np.float32))
@@ -194,3 +200,21 @@ class FAISS(VectorStore):
             {index_to_id[i]: doc for i, doc in enumerate(documents)}
         )
         return cls(embedding.embed_query, index, docstore, index_to_id)
+
+    def save_local(self, path: str) -> None:
+        """Save FAISS index to disk.
+
+        Args:
+            path: Path to save FAISS index to.
+        """
+        faiss = dependable_faiss_import()
+        faiss.write_index(self.index, path)
+
+    def load_local(self, path: str) -> None:
+        """Load FAISS index from disk.
+
+        Args:
+            path: Path to load FAISS index from.
+        """
+        faiss = dependable_faiss_import()
+        self.index = faiss.read_index(path)
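Both methods are thin wrappers over faiss's own serialization. The equivalent raw calls, sketched with a hypothetical `index.bin` path, make the delegation explicit:

```python
import faiss
import numpy as np

index = faiss.IndexFlatL2(4)
index.add(np.zeros((2, 4), dtype=np.float32))

faiss.write_index(index, "index.bin")      # what save_local delegates to
restored = faiss.read_index("index.bin")   # what load_local delegates to
assert restored.ntotal == index.ntotal
```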
@@ -1,4 +1,5 @@
 """Test FAISS functionality."""
+import tempfile
 from typing import List
 
 import pytest
@@ -46,9 +47,15 @@ def test_faiss_with_metadatas() -> None:
     docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
     expected_docstore = InMemoryDocstore(
         {
-            "0": Document(page_content="foo", metadata={"page": 0}),
-            "1": Document(page_content="bar", metadata={"page": 1}),
-            "2": Document(page_content="baz", metadata={"page": 2}),
+            docsearch.index_to_docstore_id[0]: Document(
+                page_content="foo", metadata={"page": 0}
+            ),
+            docsearch.index_to_docstore_id[1]: Document(
+                page_content="bar", metadata={"page": 1}
+            ),
+            docsearch.index_to_docstore_id[2]: Document(
+                page_content="baz", metadata={"page": 2}
+            ),
         }
     )
     assert docsearch.docstore.__dict__ == expected_docstore.__dict__
@@ -82,3 +89,15 @@ def test_faiss_add_texts_not_supported() -> None:
     docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {})
     with pytest.raises(ValueError):
         docsearch.add_texts(["foo"])
+
+
+def test_faiss_local_save_load() -> None:
+    """Test end to end serialization."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
+
+    with tempfile.NamedTemporaryFile() as temp_file:
+        docsearch.save_local(temp_file.name)
+        docsearch.index = None
+        docsearch.load_local(temp_file.name)
+        assert docsearch.index is not None