Update VectorStore interface to contain from_texts, enforce common in… (#97)

…terface
This commit is contained in:
Samantha Whitmore 2022-11-08 21:55:22 -08:00 committed by GitHub
parent 61f12229df
commit 2ddab88c06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 17 deletions

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"id": "965eecee", "id": "965eecee",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -15,7 +15,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"id": "68481687", "id": "68481687",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -30,7 +30,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"id": "015f4ff5", "id": "015f4ff5",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -43,7 +43,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"id": "67baf32e", "id": "67baf32e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -69,12 +69,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 6,
"id": "4906b8a3", "id": "4906b8a3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"docsearch = ElasticVectorSearch.from_texts(\"http://localhost:9200\", texts, embeddings)\n", "docsearch = ElasticVectorSearch.from_texts(texts, embeddings, elasticsearch_url=\"http://localhost:9200\")\n",
"\n", "\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.similarity_search(query)" "docs = docsearch.similarity_search(query)"
@ -82,7 +82,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 7,
"id": "95f9eee9", "id": "95f9eee9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [

View File

@ -1,8 +1,9 @@
"""Interface for vector stores.""" """Interface for vector stores."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import Any, List
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
class VectorStore(ABC): class VectorStore(ABC):
@ -11,3 +12,10 @@ class VectorStore(ABC):
@abstractmethod @abstractmethod
def similarity_search(self, query: str, k: int = 4) -> List[Document]: def similarity_search(self, query: str, k: int = 4) -> List[Document]:
"""Return docs most similar to query.""" """Return docs most similar to query."""
@classmethod
@abstractmethod
def from_texts(
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
) -> "VectorStore":
"""Return VectorStore initialized from texts and embeddings."""

View File

@ -1,6 +1,7 @@
"""Wrapper around Elasticsearch vector database.""" """Wrapper around Elasticsearch vector database."""
import os
import uuid import uuid
from typing import Callable, Dict, List from typing import Any, Callable, Dict, List
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
@ -46,7 +47,7 @@ class ElasticVectorSearch(VectorStore):
def __init__( def __init__(
self, self,
elastic_url: str, elasticsearch_url: str,
index_name: str, index_name: str,
mapping: Dict, mapping: Dict,
embedding_function: Callable, embedding_function: Callable,
@ -62,7 +63,7 @@ class ElasticVectorSearch(VectorStore):
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.index_name = index_name self.index_name = index_name
try: try:
es_client = elasticsearch.Elasticsearch(elastic_url) # noqa es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(
"Your elasticsearch client string is misformatted. " f"Got error: {e} " "Your elasticsearch client string is misformatted. " f"Got error: {e} "
@ -89,7 +90,7 @@ class ElasticVectorSearch(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls, elastic_url: str, texts: List[str], embedding: Embeddings cls, texts: List[str], embedding: Embeddings, **kwargs: Any
) -> "ElasticVectorSearch": ) -> "ElasticVectorSearch":
"""Construct ElasticVectorSearch wrapper from raw documents. """Construct ElasticVectorSearch wrapper from raw documents.
@ -107,11 +108,21 @@ class ElasticVectorSearch(VectorStore):
from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
elastic_vector_search = ElasticVectorSearch.from_texts( elastic_vector_search = ElasticVectorSearch.from_texts(
"http://localhost:9200",
texts, texts,
embeddings embeddings,
elasticsearch_url="http://localhost:9200"
) )
""" """
elasticsearch_url = kwargs.get("elasticsearch_url")
if not elasticsearch_url:
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
if elasticsearch_url is None or elasticsearch_url == "":
raise ValueError(
"Did not find Elasticsearch URL, please add an environment variable"
" `ELASTICSEARCH_URL` which contains it, or pass"
" `elasticsearch_url` as a named parameter."
)
try: try:
import elasticsearch import elasticsearch
from elasticsearch.helpers import bulk from elasticsearch.helpers import bulk
@ -121,7 +132,7 @@ class ElasticVectorSearch(VectorStore):
"Please install it with `pip install elasticearch`." "Please install it with `pip install elasticearch`."
) )
try: try:
client = elasticsearch.Elasticsearch(elastic_url) client = elasticsearch.Elasticsearch(elasticsearch_url)
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(
"Your elasticsearch client string is misformatted. " f"Got error: {e} " "Your elasticsearch client string is misformatted. " f"Got error: {e} "
@ -144,4 +155,4 @@ class ElasticVectorSearch(VectorStore):
requests.append(request) requests.append(request)
bulk(client, requests) bulk(client, requests)
client.indices.refresh(index=index_name) client.indices.refresh(index=index_name)
return cls(elastic_url, index_name, mapping, embedding.embed_query) return cls(elasticsearch_url, index_name, mapping, embedding.embed_query)

View File

@ -53,7 +53,9 @@ class FAISS(VectorStore):
return docs return docs
@classmethod @classmethod
def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS": def from_texts(
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
) -> "FAISS":
"""Construct FAISS wrapper from raw documents. """Construct FAISS wrapper from raw documents.
This is a user friendly interface that: This is a user friendly interface that: