diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index db4f7196e44..a70978930f6 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -1,6 +1,6 @@ """Interface for vector stores.""" from abc import ABC, abstractmethod -from typing import Any, List +from typing import Any, List, Optional from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -16,6 +16,10 @@ class VectorStore(ABC): @classmethod @abstractmethod def from_texts( - cls, texts: List[str], embedding: Embeddings, **kwargs: Any + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any ) -> "VectorStore": """Return VectorStore initialized from texts and embeddings.""" diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 549277b3f94..90078f74c65 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -1,6 +1,6 @@ """Wrapper around Elasticsearch vector database.""" import uuid -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -78,13 +78,19 @@ class ElasticVectorSearch(VectorStore): embedding = self.embedding_function(query) script_query = _default_script_query(embedding) response = self.client.search(index=self.index_name, query=script_query) - texts = [hit["_source"]["text"] for hit in response["hits"]["hits"][:k]] - documents = [Document(page_content=text) for text in texts] + hits = [hit["_source"] for hit in response["hits"]["hits"][:k]] + documents = [ + Document(page_content=hit["text"], metadata=hit["metadata"]) for hit in hits + ] return documents @classmethod def from_texts( - cls, texts: List[str], embedding: Embeddings, **kwargs: Any + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, ) -> "ElasticVectorSearch": """Construct ElasticVectorSearch wrapper from raw documents. @@ -138,6 +144,7 @@ class ElasticVectorSearch(VectorStore): "_index": index_name, "vector": embeddings[i], "text": text, + "metadata": metadatas[i] if metadatas else {}, } requests.append(request) bulk(client, requests) diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 937ad80e731..8ae2e3f0217 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -1,5 +1,5 @@ """Wrapper around FAISS vector database.""" -from typing import Any, Callable, List +from typing import Any, Callable, List, Optional import numpy as np @@ -54,7 +54,11 @@ class FAISS(VectorStore): @classmethod def from_texts( - cls, texts: List[str], embedding: Embeddings, **kwargs: Any + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, ) -> "FAISS": """Construct FAISS wrapper from raw documents. @@ -84,6 +88,9 @@ class FAISS(VectorStore): embeddings = embedding.embed_documents(texts) index = faiss.IndexFlatL2(len(embeddings[0])) index.add(np.array(embeddings, dtype=np.float32)) - documents = [Document(page_content=text) for text in texts] + documents = [] + for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} + documents.append(Document(page_content=text, metadata=metadata)) docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)}) return cls(embedding.embed_query, index, docstore) diff --git a/tests/integration_tests/vectorstores/__init__.py b/tests/integration_tests/vectorstores/__init__.py new file mode 100644 index 00000000000..4396bb23049 --- /dev/null +++ b/tests/integration_tests/vectorstores/__init__.py @@ -0,0 +1 @@ +"""Test vectorstores.""" diff --git a/tests/integration_tests/vectorstores/test_elasticsearch.py b/tests/integration_tests/vectorstores/test_elasticsearch.py new file mode 100644 index 00000000000..d3fd801f0a1 --- /dev/null +++ b/tests/integration_tests/vectorstores/test_elasticsearch.py @@ -0,0 +1,42 @@ +"""Test ElasticSearch functionality.""" +from typing import List + +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch + + +class FakeEmbeddings(Embeddings): + """Fake embeddings functionality for testing.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings.""" + return [[1.0] * 9 + [i] for i in range(len(texts))] + + def embed_query(self, text: str) -> List[float]: + """Return simple embeddings.""" + return [1.0] * 9 + [0.0] + + +def test_elasticsearch() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = ElasticVectorSearch.from_texts( + texts, FakeEmbeddings(), elasticsearch_url="http://localhost:9200" + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_elasticsearch_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = ElasticVectorSearch.from_texts( + texts, + FakeEmbeddings(), + metadatas=metadatas, + elasticsearch_url="http://localhost:9200", + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": 0})] diff --git a/tests/integration_tests/test_faiss.py b/tests/integration_tests/vectorstores/test_faiss.py similarity index 68% rename from tests/integration_tests/test_faiss.py rename to tests/integration_tests/vectorstores/test_faiss.py index 8df51989cbf..2b3cbd1d51c 100644 --- a/tests/integration_tests/test_faiss.py +++ b/tests/integration_tests/vectorstores/test_faiss.py @@ -37,6 +37,23 @@ def test_faiss() -> None: assert output == [Document(page_content="foo")] +def test_faiss_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) + expected_docstore = InMemoryDocstore( + { + "0": Document(page_content="foo", metadata={"page": 0}), + "1": Document(page_content="bar", metadata={"page": 1}), + "2": Document(page_content="baz", metadata={"page": 2}), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": 0})] + + def test_faiss_search_not_found() -> None: """Test what happens when document is not found.""" texts = ["foo", "bar", "baz"]