mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 21:08:59 +00:00
wip: add method for both docstore and embeddings (#119)
this will break atm but wanted to get thoughts on implementation. 1. should add() be on docstore interface? 2. should InMemoryDocstore change to take a list of documents as init? (makes this slightly easier to implement in FAISS -- if we think it is less clean then could expose a method to get the number of documents currently in the dict, and perform the logic of creating the necessary dictionary in the FAISS.add_texts method. Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
e9baf9c134
commit
315b0c09c6
@ -1,6 +1,6 @@
|
|||||||
"""Interface to access to place that stores documents."""
|
"""Interface to access to place that stores documents."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
@ -15,3 +15,11 @@ class Docstore(ABC):
|
|||||||
If page exists, return the page summary, and a Document object.
|
If page exists, return the page summary, and a Document object.
|
||||||
If page does not exist, return similar entries.
|
If page does not exist, return similar entries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AddableMixin(ABC):
|
||||||
|
"""Mixin class that supports adding texts."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add(self, texts: Dict[str, Document]) -> None:
|
||||||
|
"""Add more documents."""
|
||||||
|
@ -1,17 +1,24 @@
|
|||||||
"""Simple in memory docstore in the form of a dict."""
|
"""Simple in memory docstore in the form of a dict."""
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
from langchain.docstore.base import Docstore
|
from langchain.docstore.base import AddableMixin, Docstore
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
|
||||||
class InMemoryDocstore(Docstore):
|
class InMemoryDocstore(Docstore, AddableMixin):
|
||||||
"""Simple in memory docstore in the form of a dict."""
|
"""Simple in memory docstore in the form of a dict."""
|
||||||
|
|
||||||
def __init__(self, _dict: Dict[str, Document]):
|
def __init__(self, _dict: Dict[str, Document]):
|
||||||
"""Initialize with dict."""
|
"""Initialize with dict."""
|
||||||
self._dict = _dict
|
self._dict = _dict
|
||||||
|
|
||||||
|
def add(self, texts: Dict[str, Document]) -> None:
|
||||||
|
"""Add texts to in memory dictionary."""
|
||||||
|
overlapping = set(texts).intersection(self._dict)
|
||||||
|
if overlapping:
|
||||||
|
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
|
||||||
|
self._dict = dict(self._dict, **texts)
|
||||||
|
|
||||||
def search(self, search: str) -> Union[str, Document]:
|
def search(self, search: str) -> Union[str, Document]:
|
||||||
"""Search via direct lookup."""
|
"""Search via direct lookup."""
|
||||||
if search not in self._dict:
|
if search not in self._dict:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Interface for vector stores."""
|
"""Interface for vector stores."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -9,6 +9,10 @@ from langchain.embeddings.base import Embeddings
|
|||||||
class VectorStore(ABC):
|
class VectorStore(ABC):
|
||||||
"""Interface for vector stores."""
|
"""Interface for vector stores."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_texts(self, texts: Iterable[str]) -> None:
|
||||||
|
"""Run more texts through the embeddings and add to the vectorstore."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
||||||
"""Return docs most similar to query."""
|
"""Return docs most similar to query."""
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Wrapper around Elasticsearch vector database."""
|
"""Wrapper around Elasticsearch vector database."""
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -65,6 +65,28 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
)
|
)
|
||||||
self.client = es_client
|
self.client = es_client
|
||||||
|
|
||||||
|
def add_texts(self, texts: Iterable[str]) -> None:
|
||||||
|
"""Run more texts through the embeddings and add to the vectorstore."""
|
||||||
|
try:
|
||||||
|
from elasticsearch.helpers import bulk
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import elasticsearch python package. "
|
||||||
|
"Please install it with `pip install elasticearch`."
|
||||||
|
)
|
||||||
|
requests = []
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
request = {
|
||||||
|
"_op_type": "index",
|
||||||
|
"_index": self.index_name,
|
||||||
|
"vector": self.embedding_function(text),
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
requests.append(request)
|
||||||
|
bulk(self.client, requests)
|
||||||
|
# TODO: add option not to refresh
|
||||||
|
self.client.indices.refresh(index=self.index_name)
|
||||||
|
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
"""Wrapper around FAISS vector database."""
|
"""Wrapper around FAISS vector database."""
|
||||||
from typing import Any, Callable, List, Optional
|
import uuid
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from langchain.docstore.base import Docstore
|
from langchain.docstore.base import AddableMixin, Docstore
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.docstore.in_memory import InMemoryDocstore
|
from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -23,11 +24,41 @@ class FAISS(VectorStore):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore):
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_function: Callable,
|
||||||
|
index: Any,
|
||||||
|
docstore: Docstore,
|
||||||
|
index_to_docstore_id: Dict[int, str],
|
||||||
|
):
|
||||||
"""Initialize with necessary components."""
|
"""Initialize with necessary components."""
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
self.index = index
|
self.index = index
|
||||||
self.docstore = docstore
|
self.docstore = docstore
|
||||||
|
self.index_to_docstore_id = index_to_docstore_id
|
||||||
|
|
||||||
|
def add_texts(self, texts: Iterable[str]) -> None:
|
||||||
|
"""Run more texts through the embeddings and add to the vectorstore."""
|
||||||
|
if not isinstance(self.docstore, AddableMixin):
|
||||||
|
raise ValueError(
|
||||||
|
"If trying to add texts, the underlying docstore should support "
|
||||||
|
f"adding items, which {self.docstore} does not"
|
||||||
|
)
|
||||||
|
# Embed and create the documents.
|
||||||
|
embeddings = [self.embedding_function(text) for text in texts]
|
||||||
|
documents = [Document(page_content=text) for text in texts]
|
||||||
|
# Add to the index, the index_to_id mapping, and the docstore.
|
||||||
|
starting_len = len(self.index_to_docstore_id)
|
||||||
|
self.index.add(np.array(embeddings, dtype=np.float32))
|
||||||
|
# Get list of index, id, and docs.
|
||||||
|
full_info = [
|
||||||
|
(starting_len + i, str(uuid.uuid4()), doc)
|
||||||
|
for i, doc in enumerate(documents)
|
||||||
|
]
|
||||||
|
# Add information to docstore and index.
|
||||||
|
self.docstore.add({_id: doc for _, _id, doc in full_info})
|
||||||
|
index_to_id = {index: _id for index, _id, _ in full_info}
|
||||||
|
self.index_to_docstore_id.update(index_to_id)
|
||||||
|
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
@ -46,9 +77,10 @@ class FAISS(VectorStore):
|
|||||||
if i == -1:
|
if i == -1:
|
||||||
# This happens when not enough docs are returned.
|
# This happens when not enough docs are returned.
|
||||||
continue
|
continue
|
||||||
doc = self.docstore.search(str(i))
|
_id = self.index_to_docstore_id[i]
|
||||||
|
doc = self.docstore.search(_id)
|
||||||
if not isinstance(doc, Document):
|
if not isinstance(doc, Document):
|
||||||
raise ValueError(f"Could not find document for id {i}, got {doc}")
|
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@ -92,5 +124,8 @@ class FAISS(VectorStore):
|
|||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
metadata = metadatas[i] if metadatas else {}
|
metadata = metadatas[i] if metadatas else {}
|
||||||
documents.append(Document(page_content=text, metadata=metadata))
|
documents.append(Document(page_content=text, metadata=metadata))
|
||||||
docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
|
index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
|
||||||
return cls(embedding.embed_query, index, docstore)
|
docstore = InMemoryDocstore(
|
||||||
|
{index_to_id[i]: doc for i, doc in enumerate(documents)}
|
||||||
|
)
|
||||||
|
return cls(embedding.embed_query, index, docstore, index_to_id)
|
||||||
|
@ -5,6 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.docstore.in_memory import InMemoryDocstore
|
from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
|
from langchain.docstore.wikipedia import Wikipedia
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores.faiss import FAISS
|
from langchain.vectorstores.faiss import FAISS
|
||||||
|
|
||||||
@ -25,11 +26,12 @@ def test_faiss() -> None:
|
|||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
||||||
|
index_to_id = docsearch.index_to_docstore_id
|
||||||
expected_docstore = InMemoryDocstore(
|
expected_docstore = InMemoryDocstore(
|
||||||
{
|
{
|
||||||
"0": Document(page_content="foo"),
|
index_to_id[0]: Document(page_content="foo"),
|
||||||
"1": Document(page_content="bar"),
|
index_to_id[1]: Document(page_content="bar"),
|
||||||
"2": Document(page_content="baz"),
|
index_to_id[2]: Document(page_content="baz"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||||
@ -62,3 +64,21 @@ def test_faiss_search_not_found() -> None:
|
|||||||
docsearch.docstore = InMemoryDocstore({})
|
docsearch.docstore = InMemoryDocstore({})
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
docsearch.similarity_search("foo")
|
docsearch.similarity_search("foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_add_texts() -> None:
|
||||||
|
"""Test end to end adding of texts."""
|
||||||
|
# Create initial doc store.
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
|
||||||
|
# Test adding a similar document as before.
|
||||||
|
docsearch.add_texts(["foo"])
|
||||||
|
output = docsearch.similarity_search("foo", k=2)
|
||||||
|
assert output == [Document(page_content="foo"), Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_add_texts_not_supported() -> None:
|
||||||
|
"""Test adding of texts to a docstore that doesn't support it."""
|
||||||
|
docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {})
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
docsearch.add_texts(["foo"])
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test in memory docstore."""
|
"""Test in memory docstore."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.docstore.in_memory import InMemoryDocstore
|
from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
@ -19,3 +20,37 @@ def test_document_not_found() -> None:
|
|||||||
docstore = InMemoryDocstore(_dict)
|
docstore = InMemoryDocstore(_dict)
|
||||||
output = docstore.search("bar")
|
output = docstore.search("bar")
|
||||||
assert output == "ID bar not found."
|
assert output == "ID bar not found."
|
||||||
|
|
||||||
|
|
||||||
|
def test_adding_document() -> None:
|
||||||
|
"""Test that documents are added correctly."""
|
||||||
|
_dict = {"foo": Document(page_content="bar")}
|
||||||
|
docstore = InMemoryDocstore(_dict)
|
||||||
|
new_dict = {"bar": Document(page_content="foo")}
|
||||||
|
docstore.add(new_dict)
|
||||||
|
|
||||||
|
# Test that you can find new document.
|
||||||
|
foo_output = docstore.search("bar")
|
||||||
|
assert isinstance(foo_output, Document)
|
||||||
|
assert foo_output.page_content == "foo"
|
||||||
|
|
||||||
|
# Test that old document is the same.
|
||||||
|
bar_output = docstore.search("foo")
|
||||||
|
assert isinstance(bar_output, Document)
|
||||||
|
assert bar_output.page_content == "bar"
|
||||||
|
|
||||||
|
|
||||||
|
def test_adding_document_already_exists() -> None:
|
||||||
|
"""Test that error is raised if document id already exists."""
|
||||||
|
_dict = {"foo": Document(page_content="bar")}
|
||||||
|
docstore = InMemoryDocstore(_dict)
|
||||||
|
new_dict = {"foo": Document(page_content="foo")}
|
||||||
|
|
||||||
|
# Test that error is raised.
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
docstore.add(new_dict)
|
||||||
|
|
||||||
|
# Test that old document is the same.
|
||||||
|
bar_output = docstore.search("foo")
|
||||||
|
assert isinstance(bar_output, Document)
|
||||||
|
assert bar_output.page_content == "bar"
|
||||||
|
Loading…
Reference in New Issue
Block a user