mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 18:53:02 +00:00
Update VectorStore interface to contain from_texts, enforce common interface (#97)
This commit is contained in:
parent
61f12229df
commit
2ddab88c06
@ -2,7 +2,7 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 2,
|
||||||
"id": "965eecee",
|
"id": "965eecee",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -15,7 +15,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 3,
|
||||||
"id": "68481687",
|
"id": "68481687",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -30,7 +30,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"id": "015f4ff5",
|
"id": "015f4ff5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -43,7 +43,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"id": "67baf32e",
|
"id": "67baf32e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -69,12 +69,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 6,
|
||||||
"id": "4906b8a3",
|
"id": "4906b8a3",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"docsearch = ElasticVectorSearch.from_texts(\"http://localhost:9200\", texts, embeddings)\n",
|
"docsearch = ElasticVectorSearch.from_texts(texts, embeddings, elasticsearch_url=\"http://localhost:9200\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||||
"docs = docsearch.similarity_search(query)"
|
"docs = docsearch.similarity_search(query)"
|
||||||
@ -82,7 +82,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 7,
|
||||||
"id": "95f9eee9",
|
"id": "95f9eee9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
"""Interface for vector stores."""
|
"""Interface for vector stores."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
|
|
||||||
class VectorStore(ABC):
|
class VectorStore(ABC):
|
||||||
@ -11,3 +12,10 @@ class VectorStore(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
||||||
"""Return docs most similar to query."""
|
"""Return docs most similar to query."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_texts(
|
||||||
|
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
|
||||||
|
) -> "VectorStore":
|
||||||
|
"""Return VectorStore initialized from texts and embeddings."""
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Wrapper around Elasticsearch vector database."""
|
"""Wrapper around Elasticsearch vector database."""
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Callable, Dict, List
|
from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -46,7 +47,7 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
elastic_url: str,
|
elasticsearch_url: str,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
mapping: Dict,
|
mapping: Dict,
|
||||||
embedding_function: Callable,
|
embedding_function: Callable,
|
||||||
@ -62,7 +63,7 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
self.index_name = index_name
|
self.index_name = index_name
|
||||||
try:
|
try:
|
||||||
es_client = elasticsearch.Elasticsearch(elastic_url) # noqa
|
es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
||||||
@ -89,7 +90,7 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls, elastic_url: str, texts: List[str], embedding: Embeddings
|
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
|
||||||
) -> "ElasticVectorSearch":
|
) -> "ElasticVectorSearch":
|
||||||
"""Construct ElasticVectorSearch wrapper from raw documents.
|
"""Construct ElasticVectorSearch wrapper from raw documents.
|
||||||
|
|
||||||
@ -107,11 +108,21 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
elastic_vector_search = ElasticVectorSearch.from_texts(
|
elastic_vector_search = ElasticVectorSearch.from_texts(
|
||||||
"http://localhost:9200",
|
|
||||||
texts,
|
texts,
|
||||||
embeddings
|
embeddings,
|
||||||
|
elasticsearch_url="http://localhost:9200"
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
elasticsearch_url = kwargs.get("elasticsearch_url")
|
||||||
|
if not elasticsearch_url:
|
||||||
|
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
|
||||||
|
|
||||||
|
if elasticsearch_url is None or elasticsearch_url == "":
|
||||||
|
raise ValueError(
|
||||||
|
"Did not find Elasticsearch URL, please add an environment variable"
|
||||||
|
" `ELASTICSEARCH_URL` which contains it, or pass"
|
||||||
|
" `elasticsearch_url` as a named parameter."
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
import elasticsearch
|
import elasticsearch
|
||||||
from elasticsearch.helpers import bulk
|
from elasticsearch.helpers import bulk
|
||||||
@ -121,7 +132,7 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
"Please install it with `pip install elasticearch`."
|
"Please install it with `pip install elasticearch`."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
client = elasticsearch.Elasticsearch(elastic_url)
|
client = elasticsearch.Elasticsearch(elasticsearch_url)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
||||||
@ -144,4 +155,4 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
requests.append(request)
|
requests.append(request)
|
||||||
bulk(client, requests)
|
bulk(client, requests)
|
||||||
client.indices.refresh(index=index_name)
|
client.indices.refresh(index=index_name)
|
||||||
return cls(elastic_url, index_name, mapping, embedding.embed_query)
|
return cls(elasticsearch_url, index_name, mapping, embedding.embed_query)
|
||||||
|
@ -53,7 +53,9 @@ class FAISS(VectorStore):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS":
|
def from_texts(
|
||||||
|
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
|
||||||
|
) -> "FAISS":
|
||||||
"""Construct FAISS wrapper from raw documents.
|
"""Construct FAISS wrapper from raw documents.
|
||||||
|
|
||||||
This is a user friendly interface that:
|
This is a user friendly interface that:
|
||||||
|
Loading…
Reference in New Issue
Block a user