mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-15 12:02:11 +00:00
Harrison/add vald (#10807)
Co-authored-by: datelier <57349093+datelier@users.noreply.github.com>
This commit is contained in:
parent
bbc3fe259b
commit
d2bee34d4c
175
docs/extras/integrations/vectorstores/vald.ipynb
Normal file
175
docs/extras/integrations/vectorstores/vald.ipynb
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "25bce5eb-8599-40fe-947e-4932cfae8184",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Vald\n",
|
||||||
|
"\n",
|
||||||
|
"> [Vald](https://github.com/vdaas/vald) is a highly scalable distributed fast approximate nearest neighbor (ANN) dense vector search engine.\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use functionality related to the `Vald` database.\n",
|
||||||
|
"\n",
|
||||||
|
"To run this notebook you need a running Vald cluster.\n",
|
||||||
|
"Check [Get Started](https://github.com/vdaas/vald#get-started) for more information.\n",
|
||||||
|
"\n",
|
||||||
|
"See the [installation instructions](https://github.com/vdaas/vald-client-python#install)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f45f46f2-7229-4859-9797-30bbead1b8e0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!pip install vald-client-python"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2f65caa9-8383-409a-bccb-6e91fc8d5e8f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Basic Example"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "eab0b1e4-9793-4be7-a2ba-e4455c21ea22",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.document_loaders import TextLoader\n",
|
||||||
|
"from langchain.embeddings import HuggingFaceEmbeddings\n",
|
||||||
|
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||||
|
"from langchain.vectorstores import Vald\n",
|
||||||
|
"\n",
|
||||||
|
"raw_documents = TextLoader('state_of_the_union.txt').load()\n",
|
||||||
|
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||||
|
"documents = text_splitter.split_documents(raw_documents)\n",
|
||||||
|
"embeddings = HuggingFaceEmbeddings()\n",
|
||||||
|
"db = Vald.from_documents(documents, embeddings, host=\"localhost\", port=8080)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b0a6797c-2bb0-45db-a636-5d2437f7a4c0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||||
|
"docs = db.similarity_search(query)\n",
|
||||||
|
"docs[0].page_content"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c4c4e06d-6def-44ce-ac9a-4c01673c29a2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Similarity search by vector"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1eb72610-d451-4158-880c-9f0d45fa5909",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embedding_vector = embeddings.embed_query(query)\n",
|
||||||
|
"docs = db.similarity_search_by_vector(embedding_vector)\n",
|
||||||
|
"docs[0].page_content"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d33588d4-67c2-4bd3-b251-76ae783cbafb",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Similarity search with score"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1a41e382-0336-4e6d-b2ef-44cc77db2696",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"docs_and_scores = db.similarity_search_with_score(query)\n",
|
||||||
|
"docs_and_scores[0]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "57f930f2-41a0-4795-ad9e-44a33c8f88ec",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Maximal Marginal Relevance Search (MMR)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "4790e437-3207-45cb-b121-d857ab5aabd8",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"In addition to using similarity search in the retriever object, you can also use `mmr` as the retriever's search type."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "495754b1-5cdb-4af6-9733-f68700bb7232",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = db.as_retriever(search_type=\"mmr\")\n",
|
||||||
|
"retriever.get_relevant_documents(query)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e213d957-e439-4bd6-90f2-8909323f5f09",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Or use `max_marginal_relevance_search` directly:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "99d928d0-3b79-4588-925e-32230e12af47",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"db.max_marginal_relevance_search(query, k=2, fetch_k=10)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -71,6 +71,7 @@ from langchain.vectorstores.tencentvectordb import TencentVectorDB
|
|||||||
from langchain.vectorstores.tigris import Tigris
|
from langchain.vectorstores.tigris import Tigris
|
||||||
from langchain.vectorstores.typesense import Typesense
|
from langchain.vectorstores.typesense import Typesense
|
||||||
from langchain.vectorstores.usearch import USearch
|
from langchain.vectorstores.usearch import USearch
|
||||||
|
from langchain.vectorstores.vald import Vald
|
||||||
from langchain.vectorstores.vectara import Vectara
|
from langchain.vectorstores.vectara import Vectara
|
||||||
from langchain.vectorstores.weaviate import Weaviate
|
from langchain.vectorstores.weaviate import Weaviate
|
||||||
from langchain.vectorstores.zep import ZepVectorStore
|
from langchain.vectorstores.zep import ZepVectorStore
|
||||||
@ -133,6 +134,7 @@ __all__ = [
|
|||||||
"Tigris",
|
"Tigris",
|
||||||
"Typesense",
|
"Typesense",
|
||||||
"USearch",
|
"USearch",
|
||||||
|
"Vald",
|
||||||
"Vectara",
|
"Vectara",
|
||||||
"VectorStore",
|
"VectorStore",
|
||||||
"Weaviate",
|
"Weaviate",
|
||||||
|
375
libs/langchain/langchain/vectorstores/vald.py
Normal file
375
libs/langchain/langchain/vectorstores/vald.py
Normal file
@ -0,0 +1,375 @@
|
|||||||
|
"""Wrapper around Vald vector database."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Iterable, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.schema.embeddings import Embeddings
|
||||||
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
|
|
||||||
|
class Vald(VectorStore):
    """Wrapper around Vald vector database.

    To use, you should have the ``vald-client-python`` python package installed.

    Every text is upserted with the text itself as the Vald object id, and
    search results are rebuilt as ``Document(page_content=<id>)``. Metadata is
    accepted by the public methods for interface compatibility but is not sent
    to Vald by this class.

    A new insecure gRPC channel is opened for each operation and closed when
    the operation finishes (including when it fails).

    Example:
        .. code-block:: python

            from langchain.embeddings import HuggingFaceEmbeddings
            from langchain.vectorstores import Vald

            texts = ['foo', 'bar', 'baz']
            vald = Vald.from_texts(
                texts=texts,
                embedding=HuggingFaceEmbeddings(),
                host="localhost",
                port=8080,
                skip_strict_exist_check=False,
            )
    """

    def __init__(
        self,
        embedding: Embeddings,
        host: str = "localhost",
        port: int = 8080,
        grpc_options: Tuple = (
            ("grpc.keepalive_time_ms", 1000 * 10),
            ("grpc.keepalive_timeout_ms", 1000 * 10),
        ),
    ):
        """Store the embedding model and the gRPC connection settings.

        Args:
            embedding: Embedding model used for documents and queries.
            host: Host name of the Vald endpoint.
            port: Port of the Vald endpoint.
            grpc_options: Channel options passed to ``grpc.insecure_channel``.
        """
        self._embedding = embedding
        self.target = host + ":" + str(port)
        self.grpc_options = grpc_options

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embedding model given to the constructor."""
        return self._embedding

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        skip_strict_exist_check: bool = False,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts`` and upsert them into Vald one by one.

        Args:
            texts: Texts to store. Each text is also used as its Vald id.
            metadatas: Accepted for interface compatibility; not used here.
            skip_strict_exist_check: Forwarded to ``payload_pb2.Upsert.Config``.
                The original note on this argument marks it as deprecated.

        Returns:
            The ``uuid`` field of each upsert response, in input order.

        Raises:
            ValueError: If ``vald-client-python`` cannot be imported.
        """
        try:
            import grpc
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import upsert_pb2_grpc
        except ImportError:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            )

        # Materialize once: ``texts`` may be a one-shot iterable, and it is
        # needed both for embedding and for pairing each text with its vector.
        text_list = list(texts)
        embs = self._embedding.embed_documents(text_list)

        channel = grpc.insecure_channel(self.target, options=self.grpc_options)
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = upsert_pb2_grpc.UpsertStub(channel)
            cfg = payload_pb2.Upsert.Config(
                skip_strict_exist_check=skip_strict_exist_check
            )

            ids = []
            for text, emb in zip(text_list, embs):
                vec = payload_pb2.Object.Vector(id=text, vector=emb)
                res = stub.Upsert(payload_pb2.Upsert.Request(vector=vec, config=cfg))
                ids.append(res.uuid)
        finally:
            # Release the channel even when an RPC fails part-way through.
            channel.close()

        return ids

    def delete(
        self,
        ids: Optional[List[str]] = None,
        skip_strict_exist_check: bool = False,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Remove the objects with the given ids from Vald.

        Args:
            ids: Vald object ids to remove (for this class, the stored texts).
            skip_strict_exist_check: Forwarded to ``payload_pb2.Remove.Config``.
                The original note on this argument marks it as deprecated.

        Returns:
            ``True`` once every remove request has been issued.

        Raises:
            ValueError: If ``vald-client-python`` cannot be imported, or if
                ``ids`` is ``None``.
        """
        try:
            import grpc
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import remove_pb2_grpc
        except ImportError:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            )

        if ids is None:
            raise ValueError("No ids provided to delete")

        channel = grpc.insecure_channel(self.target, options=self.grpc_options)
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = remove_pb2_grpc.RemoveStub(channel)
            cfg = payload_pb2.Remove.Config(
                skip_strict_exist_check=skip_strict_exist_check
            )

            for _id in ids:
                oid = payload_pb2.Object.ID(id=_id)
                _ = stub.Remove(payload_pb2.Remove.Request(id=oid, config=cfg))
        finally:
            channel.close()

        return True

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        **kwargs: Any,
    ) -> List[Document]:
        """Return the documents most similar to ``query``.

        Embeds the query and delegates to ``similarity_search_with_score``,
        dropping the scores. ``k``, ``radius``, ``epsilon`` and ``timeout`` are
        forwarded to the Vald search config; extra ``kwargs`` are ignored.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, radius, epsilon, timeout
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return ``(document, distance)`` pairs for the text ``query``.

        The query is embedded with the configured embedding model and the
        search is delegated to ``similarity_search_with_score_by_vector``.
        """
        emb = self._embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            emb, k, radius, epsilon, timeout
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        **kwargs: Any,
    ) -> List[Document]:
        """Return the documents most similar to the given embedding vector.

        Delegates to ``similarity_search_with_score_by_vector`` and drops the
        scores.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding, k, radius, epsilon, timeout
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Search Vald with a raw vector and return documents with distances.

        Args:
            embedding: Query vector.
            k: Number of results requested (``num`` in the Vald search config).
            radius: Forwarded to ``payload_pb2.Search.Config``.
            epsilon: Forwarded to ``payload_pb2.Search.Config``.
            timeout: Forwarded to ``payload_pb2.Search.Config``.
                NOTE(review): presumably nanoseconds (3 s default) - confirm
                against the Vald API.

        Returns:
            ``(Document(page_content=result.id), result.distance)`` for each
            result, in the order returned by Vald.

        Raises:
            ValueError: If ``vald-client-python`` cannot be imported.
        """
        try:
            import grpc
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import search_pb2_grpc
        except ImportError:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            )

        channel = grpc.insecure_channel(self.target, options=self.grpc_options)
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = search_pb2_grpc.SearchStub(channel)
            cfg = payload_pb2.Search.Config(
                num=k, radius=radius, epsilon=epsilon, timeout=timeout
            )

            res = stub.Search(payload_pb2.Search.Request(vector=embedding, config=cfg))

            docs_and_scores = [
                (Document(page_content=result.id), result.distance)
                for result in res.results
            ]
        finally:
            channel.close()

        return docs_and_scores

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents selected by maximal marginal relevance for a text query.

        The query is embedded and the work is delegated to
        ``max_marginal_relevance_search_by_vector``.
        """
        emb = self._embedding.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            emb,
            k=k,
            fetch_k=fetch_k,
            radius=radius,
            epsilon=epsilon,
            timeout=timeout,
            lambda_mult=lambda_mult,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents selected by maximal marginal relevance for a vector.

        Fetches ``fetch_k`` candidates by similarity, retrieves each
        candidate's stored vector from Vald, and re-ranks them with
        ``maximal_marginal_relevance``.

        Args:
            embedding: Query vector.
            k: Number of documents to return.
            fetch_k: Number of candidates to fetch before re-ranking.
            lambda_mult: Passed to ``maximal_marginal_relevance``.
            radius: Forwarded to the candidate search.
            epsilon: Forwarded to the candidate search.
            timeout: Forwarded to the candidate search.

        Raises:
            ValueError: If ``vald-client-python`` cannot be imported.
        """
        try:
            import grpc
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import object_pb2_grpc
        except ImportError:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            )

        # The candidate pool size must be passed as ``k``: the search method
        # has no ``fetch_k`` parameter and would otherwise fall back to k=4.
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding, k=fetch_k, radius=radius, epsilon=epsilon, timeout=timeout
        )

        channel = grpc.insecure_channel(self.target, options=self.grpc_options)
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = object_pb2_grpc.ObjectStub(channel)

            docs = []
            embs = []
            for doc, _ in docs_and_scores:
                # page_content holds the Vald object id (see the search method).
                vec = stub.GetObject(
                    payload_pb2.Object.VectorRequest(
                        id=payload_pb2.Object.ID(id=doc.page_content)
                    )
                )
                embs.append(vec.vector)
                docs.append(doc)
        finally:
            channel.close()

        mmr = maximal_marginal_relevance(
            np.array(embedding),
            embs,
            lambda_mult=lambda_mult,
            k=k,
        )

        return [docs[i] for i in mmr]

    @classmethod
    def from_texts(
        cls: Type[Vald],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        host: str = "localhost",
        port: int = 8080,
        grpc_options: Tuple = (
            ("grpc.keepalive_time_ms", 1000 * 10),
            ("grpc.keepalive_timeout_ms", 1000 * 10),
        ),
        skip_strict_exist_check: bool = False,
        **kwargs: Any,
    ) -> Vald:
        """Create a ``Vald`` store and upsert ``texts`` into it.

        Args:
            texts: Texts to store.
            embedding: Embedding model used for documents and queries.
            metadatas: Passed to ``add_texts``, which does not use it.
            host: Host name of the Vald endpoint.
            port: Port of the Vald endpoint.
            grpc_options: Channel options passed to ``grpc.insecure_channel``.
            skip_strict_exist_check: Passed to ``add_texts``. The original note
                on this argument marks it as deprecated.
            **kwargs: Extra keyword arguments passed to the constructor.

        Returns:
            The populated ``Vald`` instance.
        """
        vald = cls(
            embedding=embedding,
            host=host,
            port=port,
            grpc_options=grpc_options,
            **kwargs,
        )
        vald.add_texts(
            texts=texts,
            metadatas=metadatas,
            skip_strict_exist_check=skip_strict_exist_check,
        )
        return vald
|
||||||
|
|
||||||
|
|
||||||
|
"""We will support if there are any requests."""
|
||||||
|
# async def aadd_texts(
|
||||||
|
# self,
|
||||||
|
# texts: Iterable[str],
|
||||||
|
# metadatas: Optional[List[dict]] = None,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> List[str]:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def _similarity_search_with_relevance_scores(
|
||||||
|
# self,
|
||||||
|
# query: str,
|
||||||
|
# k: int = 4,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> List[Tuple[Document, float]]:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# def similarity_search_with_relevance_scores(
|
||||||
|
# self,
|
||||||
|
# query: str,
|
||||||
|
# k: int = 4,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> List[Tuple[Document, float]]:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# async def amax_marginal_relevance_search_by_vector(
|
||||||
|
# self,
|
||||||
|
# embedding: List[float],
|
||||||
|
# k: int = 4,
|
||||||
|
# fetch_k: int = 20,
|
||||||
|
# lambda_mult: float = 0.5,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> List[Document]:
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# @classmethod
|
||||||
|
# async def afrom_texts(
|
||||||
|
# cls: Type[VST],
|
||||||
|
# texts: List[str],
|
||||||
|
# embedding: Embeddings,
|
||||||
|
# metadatas: Optional[List[dict]] = None,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> VST:
|
||||||
|
# pass
|
170
libs/langchain/tests/integration_tests/vectorstores/test_vald.py
Normal file
170
libs/langchain/tests/integration_tests/vectorstores/test_vald.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
"""Test Vald functionality."""
|
||||||
|
import time
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.vectorstores import Vald
|
||||||
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
|
FakeEmbeddings,
|
||||||
|
fake_texts,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run, you should have a Vald cluster.
|
||||||
|
https://github.com/vdaas/vald/blob/main/docs/tutorial/get-started.md
|
||||||
|
"""
|
||||||
|
|
||||||
|
WAIT_TIME = 90
|
||||||
|
|
||||||
|
|
||||||
|
def _vald_from_texts(
    metadatas: Optional[List[dict]] = None,
    host: str = "localhost",
    port: int = 8080,
    skip_strict_exist_check: bool = True,
) -> Vald:
    """Build a Vald store from ``fake_texts`` using ``FakeEmbeddings``."""
    store = Vald.from_texts(
        texts=fake_texts,
        embedding=FakeEmbeddings(),
        metadatas=metadatas,
        host=host,
        port=port,
        skip_strict_exist_check=skip_strict_exist_check,
    )
    return store
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_add_texts() -> None:
    """Adding three more texts grows the searchable set from 3 to 6."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch = _vald_from_texts(metadatas=metadatas)
    time.sleep(WAIT_TIME)  # Wait for CreateIndex

    first_hits = docsearch.similarity_search("foo", k=10)
    assert len(first_hits) == 3

    texts = ["a", "b", "c"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch.add_texts(texts, metadatas)
    time.sleep(WAIT_TIME)  # Wait for CreateIndex

    second_hits = docsearch.similarity_search("foo", k=10)
    assert len(second_hits) == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_delete() -> None:
    """Deleting one id leaves two of the three stored texts searchable."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch = _vald_from_texts(metadatas=metadatas)
    time.sleep(WAIT_TIME)

    before = docsearch.similarity_search("foo", k=10)
    assert len(before) == 3

    docsearch.delete(["foo"])
    time.sleep(WAIT_TIME)

    after = docsearch.similarity_search("foo", k=10)
    assert len(after) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_search() -> None:
    """Build a store end to end and check the text search results."""
    docsearch = _vald_from_texts()
    time.sleep(WAIT_TIME)

    expected = [Document(page_content=text) for text in ("foo", "bar", "baz")]
    output = docsearch.similarity_search("foo", k=3)

    assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_search_with_score() -> None:
    """Build a store end to end and check scored text search results."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch = _vald_from_texts(metadatas=metadatas)
    time.sleep(WAIT_TIME)

    output = docsearch.similarity_search_with_score("foo", k=3)
    docs = [pair[0] for pair in output]
    scores = [pair[1] for pair in output]

    expected = [Document(page_content=text) for text in ("foo", "bar", "baz")]
    assert docs == expected
    # Scores are expected to increase strictly down the ranking.
    assert scores[0] < scores[1] < scores[2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_search_by_vector() -> None:
    """Build a store end to end and check search by embedding vector."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch = _vald_from_texts(metadatas=metadatas)
    time.sleep(WAIT_TIME)

    query_vector = FakeEmbeddings().embed_query("foo")
    output = docsearch.similarity_search_by_vector(query_vector, k=3)

    expected = [Document(page_content=text) for text in ("foo", "bar", "baz")]
    assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_search_with_score_by_vector() -> None:
    """Build a store end to end and check scored search by embedding vector."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch = _vald_from_texts(metadatas=metadatas)
    time.sleep(WAIT_TIME)

    query_vector = FakeEmbeddings().embed_query("foo")
    output = docsearch.similarity_search_with_score_by_vector(query_vector, k=3)
    docs = [pair[0] for pair in output]
    scores = [pair[1] for pair in output]

    expected = [Document(page_content=text) for text in ("foo", "bar", "baz")]
    assert docs == expected
    # Scores are expected to increase strictly down the ranking.
    assert scores[0] < scores[1] < scores[2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_max_marginal_relevance_search() -> None:
    """Build a store end to end and check MMR search from a text query."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch = _vald_from_texts(metadatas=metadatas)
    time.sleep(WAIT_TIME)

    output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)

    expected = [Document(page_content=text) for text in ("foo", "bar")]
    assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_vald_max_marginal_relevance_search_by_vector() -> None:
    """Build a store end to end and check MMR search from an embedding vector."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    docsearch = _vald_from_texts(metadatas=metadatas)
    time.sleep(WAIT_TIME)

    query_vector = FakeEmbeddings().embed_query("foo")
    output = docsearch.max_marginal_relevance_search_by_vector(
        query_vector, k=2, fetch_k=3
    )

    expected = [Document(page_content=text) for text in ("foo", "bar")]
    assert output == expected
|
Loading…
Reference in New Issue
Block a user