FAISS and embedding support (#48)

Also adds embeddings and an in-memory docstore.

Parent: 798deaec2b
Commit: 76aff023d7
@@ -7,4 +7,5 @@ Welcome to LangChain
    modules/prompt
    modules/llms
+   modules/embeddings
    modules/chains
docs/modules/embeddings.rst (new file, 5 lines)
@@ -0,0 +1,5 @@
:mod:`langchain.embeddings`
===========================

.. automodule:: langchain.embeddings
   :members:
examples/embeddings.ipynb (new file, 98 lines)
@@ -0,0 +1,98 @@
A Jupyter notebook (Python 3 ipykernel, nbformat 4.5) with the following cells:

In [1]:
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.faiss import FAISS
from langchain.text_splitter import CharacterTextSplitter

In [2]:
with open('state_of_the_union.txt') as f:
    state_of_the_union = f.read()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_text(state_of_the_union)

embeddings = OpenAIEmbeddings()
docsearch = FAISS.from_texts(texts, embeddings)

In [3]:
query = "What did the president say about Ketanji Brown Jackson"
docs = docsearch.similarity_search(query)

In [4]:
print(docs[0].page_content)

Output (stdout):
Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.

A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.

And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.

(The notebook ends with one empty code cell.)
@@ -15,6 +15,7 @@ from langchain.chains import (
     SQLDatabaseChain,
 )
 from langchain.docstore import Wikipedia
+from langchain.faiss import FAISS
 from langchain.llms import Cohere, HuggingFaceHub, OpenAI
 from langchain.prompt import Prompt
 from langchain.sql_database import SQLDatabase
@@ -33,4 +34,5 @@ __all__ = [
     "HuggingFaceHub",
     "SQLDatabase",
     "SQLDatabaseChain",
+    "FAISS",
 ]
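With the second hunk's export, FAISS becomes importable from the package root, which is what the docstring examples in langchain/faiss.py below assume:

# After this commit, the top-level import works:
from langchain import FAISS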
langchain/docstore/in_memory.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Simple in memory docstore in the form of a dict."""
from typing import Dict, Union

from langchain.docstore.base import Docstore
from langchain.docstore.document import Document


class InMemoryDocstore(Docstore):
    """Simple in memory docstore in the form of a dict."""

    def __init__(self, _dict: Dict[str, Document]):
        """Initialize with dict."""
        self._dict = _dict

    def search(self, search: str) -> Union[str, Document]:
        """Search via direct lookup."""
        if search not in self._dict:
            return f"ID {search} not found."
        else:
            return self._dict[search]
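A quick sketch of the lookup contract (the ids here are illustrative): a hit returns the Document, while a miss returns a plain string rather than raising.

from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore

store = InMemoryDocstore({"0": Document(page_content="hello world")})

found = store.search("0")     # -> Document(page_content="hello world")
missing = store.search("42")  # -> the string "ID 42 not found."

FAISS.from_texts below relies on exactly this: it keys the docstore with stringified vector positions, and similarity_search treats a non-Document return as an error.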
langchain/embeddings/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
"""Wrappers around embedding modules."""
from langchain.embeddings.openai import OpenAIEmbeddings

__all__ = ["OpenAIEmbeddings"]
langchain/embeddings/base.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Interface for embedding models."""
from abc import ABC, abstractmethod
from typing import List


class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
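The interface is just these two methods, so any local implementation will do. A minimal sketch (the class and its counting scheme are illustrative, not part of this commit) of a non-networked implementation:

from typing import List

from langchain.embeddings.base import Embeddings


class CharCountEmbeddings(Embeddings):
    """Hypothetical embeddings: a bag-of-characters vector, no network calls."""

    def _embed(self, text: str) -> List[float]:
        # One dimension per letter a-z, counting occurrences in the text.
        return [float(text.lower().count(chr(c))) for c in range(ord("a"), ord("z") + 1)]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self._embed(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return self._embed(text)

The FakeEmbeddings class in the FAISS integration test below takes the same approach.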
langchain/embeddings/openai.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""Wrapper around OpenAI embedding models."""
import os
from typing import Any, Dict, List

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings


class OpenAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Example:
        .. code-block:: python

            from langchain.embeddings import OpenAIEmbeddings
            openai = OpenAIEmbeddings(model_name="davinci")
    """

    client: Any  #: :meta private:
    model_name: str = "babbage"
    """Model name to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find OpenAI API key, please add an environment variable"
                " `OPENAI_API_KEY` which contains it."
            )
        try:
            import openai

            values["client"] = openai.Embedding
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        return values

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint."""
        # Replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
        return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        responses = [
            self._embedding_func(text, engine=f"text-search-{self.model_name}-doc-001")
            for text in texts
        ]
        return responses

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embedding = self._embedding_func(
            text, engine=f"text-search-{self.model_name}-query-001"
        )
        return embedding
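A usage sketch (assumes the openai package is installed and OPENAI_API_KEY is set in the environment). Note how the engine name is derived from model_name, so the default "babbage" resolves to text-search-babbage-doc-001 for documents and text-search-babbage-query-001 for queries:

from langchain.embeddings.openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()  # model_name defaults to "babbage"

doc_vectors = embeddings.embed_documents(["foo bar", "baz qux"])
query_vector = embeddings.embed_query("foo")

assert len(doc_vectors) == 2        # one vector per input text
assert len(doc_vectors[0]) == 2048  # babbage vectors are 2048-dim, per the tests below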
langchain/faiss.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""Wrapper around FAISS vector database."""
from typing import Any, Callable, List

import numpy as np

from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings


class FAISS:
    """Wrapper around FAISS vector database.

    To use, you should have the ``faiss`` python package installed.

    Example:
        .. code-block:: python

            from langchain import FAISS
            faiss = FAISS(embedding_function, index, docstore)

    """

    def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore):
        """Initialize with necessary components."""
        self.embedding_function = embedding_function
        self.index = index
        self.docstore = docstore

    def similarity_search(self, query: str, k: int = 4) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.
        """
        embedding = self.embedding_function(query)
        _, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
        docs = []
        for i in indices[0]:
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            doc = self.docstore.search(str(i))
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {i}, got {doc}")
            docs.append(doc)
        return docs

    @classmethod
    def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS":
        """Construct FAISS wrapper from raw documents.

        This is a user-friendly interface that:
            1. Embeds documents.
            2. Creates an in-memory docstore.
            3. Initializes the FAISS database.

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import FAISS
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                faiss = FAISS.from_texts(texts, embeddings)
        """
        try:
            import faiss
        except ImportError:
            raise ValueError(
                "Could not import faiss python package. "
                "Please install it with `pip install faiss` "
                "or `pip install faiss-cpu` (depending on Python version)."
            )
        embeddings = embedding.embed_documents(texts)
        index = faiss.IndexFlatL2(len(embeddings[0]))
        index.add(np.array(embeddings, dtype=np.float32))
        documents = [Document(page_content=text) for text in texts]
        docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
        return cls(embedding.embed_query, index, docstore)
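For reference, here is what from_texts does, unrolled by hand as a sketch (assumes faiss, numpy, and an OpenAI key are available; the sample texts and query are illustrative):

import faiss
import numpy as np

from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.faiss import FAISS

texts = ["harrison worked at kensho", "bears like honey"]
embedding = OpenAIEmbeddings()

# 1. Embed the documents.
vectors = embedding.embed_documents(texts)
# 2. Build an in-memory docstore keyed by stringified list positions,
#    matching the integer ids FAISS will return from the index.
docstore = InMemoryDocstore(
    {str(i): Document(page_content=t) for i, t in enumerate(texts)}
)
# 3. Index the vectors in a flat L2 (euclidean distance) FAISS index.
index = faiss.IndexFlatL2(len(vectors[0]))
index.add(np.array(vectors, dtype=np.float32))

docsearch = FAISS(embedding.embed_query, index, docstore)
docs = docsearch.similarity_search("Where did Harrison work?", k=1)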
@@ -12,5 +12,6 @@ google-search-results
 playwright
 wikipedia
 huggingface_hub
+faiss
 # For development
 jupyter
setup.py (modified)
@@ -14,7 +14,7 @@ setup(
     version=__version__,
     packages=find_packages(),
     description="Building applications with LLMs through composability",
-    install_requires=["pydantic", "sqlalchemy"],
+    install_requires=["pydantic", "sqlalchemy", "numpy"],
     long_description=long_description,
     license="MIT",
     url="https://github.com/hwchase17/langchain",
@@ -1,3 +1,4 @@
 -e .
 # For testing
 pytest
+pytest-dotenv
tests/integration_tests/embeddings/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Test embedding integrations."""
tests/integration_tests/embeddings/test_openai.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Test openai embeddings."""
from langchain.embeddings.openai import OpenAIEmbeddings


def test_openai_embedding_documents() -> None:
    """Test openai embeddings."""
    documents = ["foo bar"]
    embedding = OpenAIEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 2048


def test_openai_embedding_query() -> None:
    """Test openai embeddings."""
    document = "foo bar"
    embedding = OpenAIEmbeddings()
    output = embedding.embed_query(document)
    assert len(output) == 2048
tests/integration_tests/test_faiss.py (new file, 47 lines)
@@ -0,0 +1,47 @@
"""Test FAISS functionality."""
from typing import List

import pytest

from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.faiss import FAISS


class FakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return simple embeddings."""
        return [[i] * 10 for i in range(len(texts))]

    def embed_query(self, text: str) -> List[float]:
        """Return simple embeddings."""
        # Matches the embedding of the first document ([0] * 10), so the first
        # text inserted is always the nearest neighbor of any query.
        return [0] * 10


def test_faiss() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
    expected_docstore = InMemoryDocstore(
        {
            "0": Document(page_content="foo"),
            "1": Document(page_content="bar"),
            "2": Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo")]


def test_faiss_search_not_found() -> None:
    """Test what happens when document is not found."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
    # Get rid of the docstore to purposefully induce errors.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        docsearch.similarity_search("foo")
tests/unit_tests/docstore/test_inmemory.py (new file, 21 lines)
@@ -0,0 +1,21 @@
"""Test in memory docstore."""
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore


def test_document_found() -> None:
    """Test document found."""
    _dict = {"foo": Document(page_content="bar")}
    docstore = InMemoryDocstore(_dict)
    output = docstore.search("foo")
    assert isinstance(output, Document)
    assert output.page_content == "bar"


def test_document_not_found() -> None:
    """Test when document is not found."""
    _dict = {"foo": Document(page_content="bar")}
    docstore = InMemoryDocstore(_dict)
    output = docstore.search("bar")
    assert output == "ID bar not found."