save metadata

This commit is contained in:
Harrison Chase 2022-11-17 07:38:17 -08:00
parent 0ac08bbca6
commit c756ba12d4
6 changed files with 87 additions and 9 deletions

View File

@ -1,6 +1,6 @@
"""Interface for vector stores.""" """Interface for vector stores."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List from typing import Any, List, Optional
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
@ -16,6 +16,10 @@ class VectorStore(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_texts( def from_texts(
cls, texts: List[str], embedding: Embeddings, **kwargs: Any cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any
) -> "VectorStore": ) -> "VectorStore":
"""Return VectorStore initialized from texts and embeddings.""" """Return VectorStore initialized from texts and embeddings."""

View File

@ -1,6 +1,6 @@
"""Wrapper around Elasticsearch vector database.""" """Wrapper around Elasticsearch vector database."""
import uuid import uuid
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List, Optional
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
@ -78,13 +78,19 @@ class ElasticVectorSearch(VectorStore):
embedding = self.embedding_function(query) embedding = self.embedding_function(query)
script_query = _default_script_query(embedding) script_query = _default_script_query(embedding)
response = self.client.search(index=self.index_name, query=script_query) response = self.client.search(index=self.index_name, query=script_query)
texts = [hit["_source"]["text"] for hit in response["hits"]["hits"][:k]] hits = [hit["_source"] for hit in response["hits"]["hits"][:k]]
documents = [Document(page_content=text) for text in texts] documents = [
Document(page_content=hit["text"], metadata=hit["metadata"]) for hit in hits
]
return documents return documents
@classmethod @classmethod
def from_texts( def from_texts(
cls, texts: List[str], embedding: Embeddings, **kwargs: Any cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "ElasticVectorSearch": ) -> "ElasticVectorSearch":
"""Construct ElasticVectorSearch wrapper from raw documents. """Construct ElasticVectorSearch wrapper from raw documents.
@ -138,6 +144,7 @@ class ElasticVectorSearch(VectorStore):
"_index": index_name, "_index": index_name,
"vector": embeddings[i], "vector": embeddings[i],
"text": text, "text": text,
"metadata": metadatas[i] if metadatas else {},
} }
requests.append(request) requests.append(request)
bulk(client, requests) bulk(client, requests)

View File

@ -1,5 +1,5 @@
"""Wrapper around FAISS vector database.""" """Wrapper around FAISS vector database."""
from typing import Any, Callable, List from typing import Any, Callable, List, Optional
import numpy as np import numpy as np
@ -54,7 +54,11 @@ class FAISS(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls, texts: List[str], embedding: Embeddings, **kwargs: Any cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "FAISS": ) -> "FAISS":
"""Construct FAISS wrapper from raw documents. """Construct FAISS wrapper from raw documents.
@ -84,6 +88,9 @@ class FAISS(VectorStore):
embeddings = embedding.embed_documents(texts) embeddings = embedding.embed_documents(texts)
index = faiss.IndexFlatL2(len(embeddings[0])) index = faiss.IndexFlatL2(len(embeddings[0]))
index.add(np.array(embeddings, dtype=np.float32)) index.add(np.array(embeddings, dtype=np.float32))
documents = [Document(page_content=text) for text in texts] documents = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata))
docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)}) docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
return cls(embedding.embed_query, index, docstore) return cls(embedding.embed_query, index, docstore)

View File

@ -0,0 +1 @@
"""Test vectorstores."""

View File

@ -0,0 +1,42 @@
"""Test ElasticSearch functionality."""
from typing import List
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings."""
return [[1.0] * 9 + [i] for i in range(len(texts))]
def embed_query(self, text: str) -> List[float]:
"""Return simple embeddings."""
return [1.0] * 9 + [0.0]
def test_elasticsearch() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = ElasticVectorSearch.from_texts(
texts, FakeEmbeddings(), elasticsearch_url="http://localhost:9200"
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_elasticsearch_with_metadatas() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = ElasticVectorSearch.from_texts(
texts,
FakeEmbeddings(),
metadatas=metadatas,
elasticsearch_url="http://localhost:9200",
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]

View File

@ -37,6 +37,23 @@ def test_faiss() -> None:
assert output == [Document(page_content="foo")] assert output == [Document(page_content="foo")]
def test_faiss_with_metadatas() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
"0": Document(page_content="foo", metadata={"page": 0}),
"1": Document(page_content="bar", metadata={"page": 1}),
"2": Document(page_content="baz", metadata={"page": 2}),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]
def test_faiss_search_not_found() -> None: def test_faiss_search_not_found() -> None:
"""Test what happens when document is not found.""" """Test what happens when document is not found."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]