mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 16:39:20 +00:00
Update VectorStore interface to contain from_texts, enforce common interface (#97)
This commit is contained in:
parent
61f12229df
commit
2ddab88c06
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"id": "965eecee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -15,7 +15,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"id": "68481687",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -30,7 +30,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"id": "015f4ff5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -43,7 +43,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"id": "67baf32e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -69,12 +69,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"id": "4906b8a3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docsearch = ElasticVectorSearch.from_texts(\"http://localhost:9200\", texts, embeddings)\n",
|
||||
"docsearch = ElasticVectorSearch.from_texts(texts, embeddings, elasticsearch_url=\"http://localhost:9200\")\n",
|
||||
"\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = docsearch.similarity_search(query)"
|
||||
@ -82,7 +82,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"id": "95f9eee9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
@ -1,8 +1,9 @@
|
||||
"""Interface for vector stores."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class VectorStore(ABC):
|
||||
@ -11,3 +12,10 @@ class VectorStore(ABC):
|
||||
@abstractmethod
|
||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
||||
"""Return docs most similar to query."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_texts(
|
||||
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
|
||||
) -> "VectorStore":
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Wrapper around Elasticsearch vector database."""
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -46,7 +47,7 @@ class ElasticVectorSearch(VectorStore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
elastic_url: str,
|
||||
elasticsearch_url: str,
|
||||
index_name: str,
|
||||
mapping: Dict,
|
||||
embedding_function: Callable,
|
||||
@ -62,7 +63,7 @@ class ElasticVectorSearch(VectorStore):
|
||||
self.embedding_function = embedding_function
|
||||
self.index_name = index_name
|
||||
try:
|
||||
es_client = elasticsearch.Elasticsearch(elastic_url) # noqa
|
||||
es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
||||
@ -89,7 +90,7 @@ class ElasticVectorSearch(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls, elastic_url: str, texts: List[str], embedding: Embeddings
|
||||
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
|
||||
) -> "ElasticVectorSearch":
|
||||
"""Construct ElasticVectorSearch wrapper from raw documents.
|
||||
|
||||
@ -107,11 +108,21 @@ class ElasticVectorSearch(VectorStore):
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings()
|
||||
elastic_vector_search = ElasticVectorSearch.from_texts(
|
||||
"http://localhost:9200",
|
||||
texts,
|
||||
embeddings
|
||||
embeddings,
|
||||
elasticsearch_url="http://localhost:9200"
|
||||
)
|
||||
"""
|
||||
elasticsearch_url = kwargs.get("elasticsearch_url")
|
||||
if not elasticsearch_url:
|
||||
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
|
||||
|
||||
if elasticsearch_url is None or elasticsearch_url == "":
|
||||
raise ValueError(
|
||||
"Did not find Elasticsearch URL, please add an environment variable"
|
||||
" `ELASTICSEARCH_URL` which contains it, or pass"
|
||||
" `elasticsearch_url` as a named parameter."
|
||||
)
|
||||
try:
|
||||
import elasticsearch
|
||||
from elasticsearch.helpers import bulk
|
||||
@ -121,7 +132,7 @@ class ElasticVectorSearch(VectorStore):
|
||||
"Please install it with `pip install elasticearch`."
|
||||
)
|
||||
try:
|
||||
client = elasticsearch.Elasticsearch(elastic_url)
|
||||
client = elasticsearch.Elasticsearch(elasticsearch_url)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
|
||||
@ -144,4 +155,4 @@ class ElasticVectorSearch(VectorStore):
|
||||
requests.append(request)
|
||||
bulk(client, requests)
|
||||
client.indices.refresh(index=index_name)
|
||||
return cls(elastic_url, index_name, mapping, embedding.embed_query)
|
||||
return cls(elasticsearch_url, index_name, mapping, embedding.embed_query)
|
||||
|
@ -53,7 +53,9 @@ class FAISS(VectorStore):
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS":
|
||||
def from_texts(
|
||||
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
|
||||
) -> "FAISS":
|
||||
"""Construct FAISS wrapper from raw documents.
|
||||
|
||||
This is a user friendly interface that:
|
||||
|
Loading…
Reference in New Issue
Block a user