Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-07 03:56:39 +00:00)
wip: add method for both docstore and embeddings (#119)
This will break at the moment, but I wanted to get thoughts on the implementation. 1. Should add() be on the Docstore interface? 2. Should InMemoryDocstore change to take a list of documents as init? (That would make this slightly easier to implement in FAISS; if we think it is less clean, we could instead expose a method to get the number of documents currently in the dict and perform the logic of creating the necessary dictionary in the FAISS.add_texts method.)

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent e9baf9c134
commit 315b0c09c6
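For orientation, here is a minimal sketch (not part of the commit) of the flow this change enables; FakeEmbeddings is the test double defined in the FAISS test file below:

    from langchain.vectorstores.faiss import FAISS

    # Build the store: embeds the texts, creates a FAISS index, an
    # InMemoryDocstore keyed by uuid4 strings, and an index -> id mapping.
    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings())
    # New in this commit: embed and index more texts after construction.
    docsearch.add_texts(["qux"])
    output = docsearch.similarity_search("foo", k=2)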
langchain/docstore/base.py
@@ -1,6 +1,6 @@
 """Interface to access to place that stores documents."""
 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Dict, Union

 from langchain.docstore.document import Document

@@ -15,3 +15,11 @@ class Docstore(ABC):
         If page exists, return the page summary, and a Document object.
         If page does not exist, return similar entries.
         """
+
+
+class AddableMixin(ABC):
+    """Mixin class that supports adding texts."""
+
+    @abstractmethod
+    def add(self, texts: Dict[str, Document]) -> None:
+        """Add more documents."""
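Any docstore can opt into mutation by also inheriting AddableMixin. A quick sketch of a hypothetical implementation (LoggingDocstore is illustrative, not in the library):

    from typing import Dict, Union

    from langchain.docstore.base import AddableMixin, Docstore
    from langchain.docstore.document import Document

    class LoggingDocstore(Docstore, AddableMixin):
        """Toy docstore that logs every batch of added documents."""

        def __init__(self) -> None:
            self._docs: Dict[str, Document] = {}

        def add(self, texts: Dict[str, Document]) -> None:
            print(f"adding {len(texts)} documents")
            self._docs.update(texts)

        def search(self, search: str) -> Union[str, Document]:
            return self._docs.get(search, f"ID {search} not found.")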
langchain/docstore/in_memory.py
@@ -1,17 +1,24 @@
 """Simple in memory docstore in the form of a dict."""
 from typing import Dict, Union

-from langchain.docstore.base import Docstore
+from langchain.docstore.base import AddableMixin, Docstore
 from langchain.docstore.document import Document


-class InMemoryDocstore(Docstore):
+class InMemoryDocstore(Docstore, AddableMixin):
     """Simple in memory docstore in the form of a dict."""

     def __init__(self, _dict: Dict[str, Document]):
         """Initialize with dict."""
         self._dict = _dict

+    def add(self, texts: Dict[str, Document]) -> None:
+        """Add texts to in memory dictionary."""
+        overlapping = set(texts).intersection(self._dict)
+        if overlapping:
+            raise ValueError(f"Tried to add ids that already exist: {overlapping}")
+        self._dict = dict(self._dict, **texts)
+
     def search(self, search: str) -> Union[str, Document]:
         """Search via direct lookup."""
         if search not in self._dict:
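A usage sketch for the new add(): the ids are merged into the underlying dict, and any overlap with existing ids raises before anything is written:

    from langchain.docstore.document import Document
    from langchain.docstore.in_memory import InMemoryDocstore

    store = InMemoryDocstore({"a": Document(page_content="foo")})
    store.add({"b": Document(page_content="bar")})  # ok, store now holds two docs
    store.add({"a": Document(page_content="baz")})  # raises ValueError: id "a" exists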
langchain/vectorstores/base.py
@@ -1,6 +1,6 @@
 """Interface for vector stores."""
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional

 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
@@ -9,6 +9,10 @@ from langchain.embeddings.base import Embeddings
 class VectorStore(ABC):
     """Interface for vector stores."""

+    @abstractmethod
+    def add_texts(self, texts: Iterable[str]) -> None:
+        """Run more texts through the embeddings and add to the vectorstore."""
+
     @abstractmethod
     def similarity_search(self, query: str, k: int = 4) -> List[Document]:
         """Return docs most similar to query."""
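To illustrate the expanded interface, a toy (hypothetical) subclass that keeps embeddings in a plain list; the real implementations, FAISS and ElasticVectorSearch, follow below:

    from typing import Callable, Iterable, List, Tuple

    import numpy as np

    from langchain.docstore.document import Document
    from langchain.vectorstores.base import VectorStore

    class ToyVectorStore(VectorStore):
        """Brute-force store: exact L2 search over an in-memory list."""

        def __init__(self, embed: Callable[[str], List[float]]) -> None:
            self.embed = embed
            self.entries: List[Tuple[List[float], Document]] = []

        def add_texts(self, texts: Iterable[str]) -> None:
            for text in texts:
                self.entries.append((self.embed(text), Document(page_content=text)))

        def similarity_search(self, query: str, k: int = 4) -> List[Document]:
            q = np.array(self.embed(query))
            ranked = sorted(
                self.entries, key=lambda e: float(np.linalg.norm(np.array(e[0]) - q))
            )
            return [doc for _, doc in ranked[:k]]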
langchain/vectorstores/elastic_vector_search.py
@@ -1,6 +1,6 @@
 """Wrapper around Elasticsearch vector database."""
 import uuid
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional

 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
@@ -65,6 +65,28 @@ class ElasticVectorSearch(VectorStore):
         )
         self.client = es_client

+    def add_texts(self, texts: Iterable[str]) -> None:
+        """Run more texts through the embeddings and add to the vectorstore."""
+        try:
+            from elasticsearch.helpers import bulk
+        except ImportError:
+            raise ValueError(
+                "Could not import elasticsearch python package. "
+                "Please install it with `pip install elasticsearch`."
+            )
+        requests = []
+        for i, text in enumerate(texts):
+            request = {
+                "_op_type": "index",
+                "_index": self.index_name,
+                "vector": self.embedding_function(text),
+                "text": text,
+            }
+            requests.append(request)
+        bulk(self.client, requests)
+        # TODO: add option not to refresh
+        self.client.indices.refresh(index=self.index_name)
+
     def similarity_search(self, query: str, k: int = 4) -> List[Document]:
         """Return docs most similar to query.

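The bulk helper sends all index operations in a single round trip, which is why add_texts builds the whole requests list before issuing one call. A rough usage sketch, assuming a local cluster and the elasticsearch_url keyword accepted by this class's existing from_texts constructor (an assumption from surrounding code, not shown in this diff):

    # Sketch: assumes Elasticsearch is running at localhost:9200.
    search = ElasticVectorSearch.from_texts(
        ["foo", "bar"], FakeEmbeddings(), elasticsearch_url="http://localhost:9200"
    )
    search.add_texts(["baz"])  # one bulk() call, then an explicit index refresh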
langchain/vectorstores/faiss.py
@@ -1,9 +1,10 @@
 """Wrapper around FAISS vector database."""
-from typing import Any, Callable, List, Optional
+import uuid
+from typing import Any, Callable, Dict, Iterable, List, Optional

 import numpy as np

-from langchain.docstore.base import Docstore
+from langchain.docstore.base import AddableMixin, Docstore
 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
 from langchain.embeddings.base import Embeddings
@@ -23,11 +24,41 @@ class FAISS(VectorStore):

     """

-    def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore):
+    def __init__(
+        self,
+        embedding_function: Callable,
+        index: Any,
+        docstore: Docstore,
+        index_to_docstore_id: Dict[int, str],
+    ):
         """Initialize with necessary components."""
         self.embedding_function = embedding_function
         self.index = index
         self.docstore = docstore
+        self.index_to_docstore_id = index_to_docstore_id
+
+    def add_texts(self, texts: Iterable[str]) -> None:
+        """Run more texts through the embeddings and add to the vectorstore."""
+        if not isinstance(self.docstore, AddableMixin):
+            raise ValueError(
+                "If trying to add texts, the underlying docstore should support "
+                f"adding items, which {self.docstore} does not"
+            )
+        # Embed and create the documents.
+        embeddings = [self.embedding_function(text) for text in texts]
+        documents = [Document(page_content=text) for text in texts]
+        # Add to the index, the index_to_id mapping, and the docstore.
+        starting_len = len(self.index_to_docstore_id)
+        self.index.add(np.array(embeddings, dtype=np.float32))
+        # Get list of index, id, and docs.
+        full_info = [
+            (starting_len + i, str(uuid.uuid4()), doc)
+            for i, doc in enumerate(documents)
+        ]
+        # Add information to docstore and index.
+        self.docstore.add({_id: doc for _, _id, doc in full_info})
+        index_to_id = {index: _id for index, _id, _ in full_info}
+        self.index_to_docstore_id.update(index_to_id)

     def similarity_search(self, query: str, k: int = 4) -> List[Document]:
         """Return docs most similar to query.
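The bookkeeping here is positional: a flat FAISS index assigns new vectors sequential positions starting at the current index size, so starting_len lines each new position up with a fresh uuid key. A small sketch of the resulting invariant (FakeEmbeddings as in the tests below):

    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings())
    docsearch.add_texts(["qux", "quux"])
    # FAISS positions 0..4 now each map to a distinct uuid4 docstore key.
    assert sorted(docsearch.index_to_docstore_id) == [0, 1, 2, 3, 4]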
@@ -46,9 +77,10 @@ class FAISS(VectorStore):
             if i == -1:
                 # This happens when not enough docs are returned.
                 continue
-            doc = self.docstore.search(str(i))
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
             if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {i}, got {doc}")
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
             docs.append(doc)
         return docs

@@ -92,5 +124,8 @@ class FAISS(VectorStore):
         for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
             documents.append(Document(page_content=text, metadata=metadata))
-        docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
-        return cls(embedding.embed_query, index, docstore)
+        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+        docstore = InMemoryDocstore(
+            {index_to_id[i]: doc for i, doc in enumerate(documents)}
+        )
+        return cls(embedding.embed_query, index, docstore, index_to_id)
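Switching from_texts from positional "0"/"1"/"2" keys to uuid4 keys means docstore ids no longer encode index positions; index_to_docstore_id alone owns that mapping, so later adds cannot collide with existing keys. Resolving a hit therefore always goes through the mapping, as in this sketch:

    docsearch = FAISS.from_texts(["foo", "bar"], FakeEmbeddings())
    first_id = docsearch.index_to_docstore_id[0]  # a uuid4 string, not "0"
    doc = docsearch.docstore.search(first_id)     # Document(page_content="foo")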
tests/integration_tests/vectorstores/test_faiss.py
@@ -5,6 +5,7 @@ import pytest

 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
+from langchain.docstore.wikipedia import Wikipedia
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores.faiss import FAISS

@@ -25,11 +26,12 @@ def test_faiss() -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
     docsearch = FAISS.from_texts(texts, FakeEmbeddings())
+    index_to_id = docsearch.index_to_docstore_id
     expected_docstore = InMemoryDocstore(
         {
-            "0": Document(page_content="foo"),
-            "1": Document(page_content="bar"),
-            "2": Document(page_content="baz"),
+            index_to_id[0]: Document(page_content="foo"),
+            index_to_id[1]: Document(page_content="bar"),
+            index_to_id[2]: Document(page_content="baz"),
         }
     )
     assert docsearch.docstore.__dict__ == expected_docstore.__dict__
@@ -62,3 +64,21 @@ def test_faiss_search_not_found() -> None:
     docsearch.docstore = InMemoryDocstore({})
     with pytest.raises(ValueError):
         docsearch.similarity_search("foo")
+
+
+def test_faiss_add_texts() -> None:
+    """Test end to end adding of texts."""
+    # Create initial doc store.
+    texts = ["foo", "bar", "baz"]
+    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
+    # Test adding a similar document as before.
+    docsearch.add_texts(["foo"])
+    output = docsearch.similarity_search("foo", k=2)
+    assert output == [Document(page_content="foo"), Document(page_content="foo")]
+
+
+def test_faiss_add_texts_not_supported() -> None:
+    """Test adding of texts to a docstore that doesn't support it."""
+    docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {})
+    with pytest.raises(ValueError):
+        docsearch.add_texts(["foo"])
tests/unit_tests/docstore/test_inmemory.py
@@ -1,4 +1,5 @@
 """Test in memory docstore."""
+import pytest

 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
@@ -19,3 +20,37 @@ def test_document_not_found() -> None:
     docstore = InMemoryDocstore(_dict)
     output = docstore.search("bar")
     assert output == "ID bar not found."
+
+
+def test_adding_document() -> None:
+    """Test that documents are added correctly."""
+    _dict = {"foo": Document(page_content="bar")}
+    docstore = InMemoryDocstore(_dict)
+    new_dict = {"bar": Document(page_content="foo")}
+    docstore.add(new_dict)
+
+    # Test that you can find new document.
+    foo_output = docstore.search("bar")
+    assert isinstance(foo_output, Document)
+    assert foo_output.page_content == "foo"
+
+    # Test that old document is the same.
+    bar_output = docstore.search("foo")
+    assert isinstance(bar_output, Document)
+    assert bar_output.page_content == "bar"
+
+
+def test_adding_document_already_exists() -> None:
+    """Test that error is raised if document id already exists."""
+    _dict = {"foo": Document(page_content="bar")}
+    docstore = InMemoryDocstore(_dict)
+    new_dict = {"foo": Document(page_content="foo")}
+
+    # Test that error is raised.
+    with pytest.raises(ValueError):
+        docstore.add(new_dict)
+
+    # Test that old document is the same.
+    bar_output = docstore.search("foo")
+    assert isinstance(bar_output, Document)
+    assert bar_output.page_content == "bar"