From 235bdb9fa79cae0e5636867ae7afefc1a43f0e5a Mon Sep 17 00:00:00 2001 From: Masanori Taniguchi Date: Thu, 30 Nov 2023 12:07:29 +0900 Subject: [PATCH] Support Vald secure connection (#13269) **Description:** When using Vald, only insecure grpc connection was supported, so secure connection is now supported. In addition, grpc metadata can be added to Vald requests to enable authentication with a token. --- .../docs/integrations/vectorstores/vald.ipynb | 152 +++++++++++++++++- libs/langchain/langchain/vectorstores/vald.py | 78 +++++++-- 2 files changed, 212 insertions(+), 18 deletions(-) diff --git a/docs/docs/integrations/vectorstores/vald.ipynb b/docs/docs/integrations/vectorstores/vald.ipynb index c726f7cf4b5..34401d4e9ce 100644 --- a/docs/docs/integrations/vectorstores/vald.ipynb +++ b/docs/docs/integrations/vectorstores/vald.ipynb @@ -149,6 +149,156 @@ "source": [ "db.max_marginal_relevance_search(query, k=2, fetch_k=10)" ] + }, + { + "cell_type": "markdown", + "id": "7dc7ce16-35af-49b7-8009-7eaadb7abbcb", + "metadata": {}, + "source": [ + "## Example of using secure connection\n", + "In order to run this notebook, it is necessary to run a Vald cluster with secure connection.\n", + "\n", + "Here is an example of a Vald cluster with the following configuration using [Athenz](https://github.com/AthenZ/athenz) authentication.\n", + "\n", + "ingress(TLS) -> [authorization-proxy](https://github.com/AthenZ/authorization-proxy)(Check athenz-role-auth in grpc metadata) -> vald-lb-gateway" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6894c02d-7a86-4600-bab1-f7e9cce79333", + "metadata": {}, + "outputs": [], + "source": [ + "import grpc\n", + "\n", + "with open(\"test_root_cacert.crt\", \"rb\") as root:\n", + " credentials = grpc.ssl_channel_credentials(root_certificates=root.read())\n", + "\n", + "# Refresh is required for server use\n", + "with open(\".ztoken\", \"rb\") as ztoken:\n", + " token = ztoken.read().strip()\n", + "\n", + "metadata = [(b\"athenz-role-auth\", token)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc15c20b-485d-435e-a2ec-c7dcb9db40b5", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import TextLoader\n", + "from langchain.embeddings import HuggingFaceEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import Vald\n", + "\n", + "raw_documents = TextLoader(\"state_of_the_union.txt\").load()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "documents = text_splitter.split_documents(raw_documents)\n", + "embeddings = HuggingFaceEmbeddings()\n", + "\n", + "db = Vald.from_documents(\n", + " documents,\n", + " embeddings,\n", + " host=\"localhost\",\n", + " port=443,\n", + " grpc_use_secure=True,\n", + " grpc_credentials=credentials,\n", + " grpc_metadata=metadata,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "069b96c6-6db2-46ce-a820-24e8933156a0", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs = db.similarity_search(query, grpc_metadata=metadata)\n", + "docs[0].page_content" + ] + }, + { + "cell_type": "markdown", + "id": "8327accb-6776-4a20-a325-b5da92e3a049", + "metadata": {}, + "source": [ + "### Similarity search by vector" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0ab2a97-83e4-490d-81a5-8aaa032d8811", + "metadata": {}, + "outputs": [], + "source": [ + "embedding_vector = embeddings.embed_query(query)\n", + "docs = db.similarity_search_by_vector(embedding_vector, grpc_metadata=metadata)\n", + "docs[0].page_content" + ] + }, + { + "cell_type": "markdown", + "id": "f3f987bd-512e-4e29-acb3-e110e74b51a2", + "metadata": {}, + "source": [ + "### Similarity search with score" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88dd39bc-8764-4a8c-ac89-06e2341aefa6", + "metadata": {}, + "outputs": [], + "source": [ + "docs_and_scores = db.similarity_search_with_score(query, grpc_metadata=metadata)\n", + "docs_and_scores[0]" + ] + }, + { + "cell_type": "markdown", + "id": "fef1bd41-484e-4845-88a9-c7f504068db0", + "metadata": {}, + "source": [ + "### Maximal Marginal Relevance Search (MMR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cf08477-87b0-41ac-9536-52dec1c5d67f", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = db.as_retriever(\n", + " search_kwargs={\"search_type\": \"mmr\", \"grpc_metadata\": metadata}\n", + ")\n", + "retriever.get_relevant_documents(query, grpc_metadata=metadata)" + ] + }, + { + "cell_type": "markdown", + "id": "f994fa57-53e4-4fe6-9418-59a5136c6fe8", + "metadata": {}, + "source": [ + "Or:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2111ce42-07c7-4ccc-bdbf-459165e3a410", + "metadata": {}, + "outputs": [], + "source": [ + "db.max_marginal_relevance_search(query, k=2, fetch_k=10, grpc_metadata=metadata)" + ] } ], "metadata": { @@ -167,7 +317,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/libs/langchain/langchain/vectorstores/vald.py b/libs/langchain/langchain/vectorstores/vald.py index a10bd2ea1b7..25d8416221e 100644 --- a/libs/langchain/langchain/vectorstores/vald.py +++ b/libs/langchain/langchain/vectorstores/vald.py @@ -41,19 +41,40 @@ class Vald(VectorStore): ("grpc.keepalive_time_ms", 1000 * 10), ("grpc.keepalive_timeout_ms", 1000 * 10), ), + grpc_use_secure: bool = False, + grpc_credentials: Optional[Any] = None, ): self._embedding = embedding self.target = host + ":" + str(port) self.grpc_options = grpc_options + self.grpc_use_secure = grpc_use_secure + self.grpc_credentials = grpc_credentials @property def embeddings(self) -> Optional[Embeddings]: return self._embedding + def _get_channel(self) -> Any: + try: + import grpc + except ImportError: + raise ValueError( + "Could not import grpcio python package. " + "Please install it with `pip install grpcio`." + ) + return ( + grpc.secure_channel( + self.target, self.grpc_credentials, options=self.grpc_options + ) + if self.grpc_use_secure + else grpc.insecure_channel(self.target, options=self.grpc_options) + ) + def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, + grpc_metadata: Optional[Any] = None, skip_strict_exist_check: bool = False, **kwargs: Any, ) -> List[str]: @@ -62,7 +83,6 @@ class Vald(VectorStore): skip_strict_exist_check: Deprecated. This is not used basically. """ try: - import grpc from vald.v1.payload import payload_pb2 from vald.v1.vald import upsert_pb2_grpc except ImportError: @@ -71,7 +91,7 @@ class Vald(VectorStore): "Please install it with `pip install vald-client-python`." ) - channel = grpc.insecure_channel(self.target, options=self.grpc_options) + channel = self._get_channel() # Depending on the network quality, # it is necessary to wait for ChannelConnectivity.READY. # _ = grpc.channel_ready_future(channel).result(timeout=10) @@ -82,7 +102,10 @@ class Vald(VectorStore): embs = self._embedding.embed_documents(list(texts)) for text, emb in zip(texts, embs): vec = payload_pb2.Object.Vector(id=text, vector=emb) - res = stub.Upsert(payload_pb2.Upsert.Request(vector=vec, config=cfg)) + res = stub.Upsert( + payload_pb2.Upsert.Request(vector=vec, config=cfg), + metadata=grpc_metadata, + ) ids.append(res.uuid) channel.close() @@ -92,6 +115,7 @@ class Vald(VectorStore): self, ids: Optional[List[str]] = None, skip_strict_exist_check: bool = False, + grpc_metadata: Optional[Any] = None, **kwargs: Any, ) -> Optional[bool]: """ @@ -99,7 +123,6 @@ class Vald(VectorStore): skip_strict_exist_check: Deprecated. This is not used basically. """ try: - import grpc from vald.v1.payload import payload_pb2 from vald.v1.vald import remove_pb2_grpc except ImportError: @@ -111,7 +134,7 @@ class Vald(VectorStore): if ids is None: raise ValueError("No ids provided to delete") - channel = grpc.insecure_channel(self.target, options=self.grpc_options) + channel = self._get_channel() # Depending on the network quality, # it is necessary to wait for ChannelConnectivity.READY. # _ = grpc.channel_ready_future(channel).result(timeout=10) @@ -120,7 +143,9 @@ class Vald(VectorStore): for _id in ids: oid = payload_pb2.Object.ID(id=_id) - _ = stub.Remove(payload_pb2.Remove.Request(id=oid, config=cfg)) + _ = stub.Remove( + payload_pb2.Remove.Request(id=oid, config=cfg), metadata=grpc_metadata + ) channel.close() return True @@ -132,10 +157,11 @@ class Vald(VectorStore): radius: float = -1.0, epsilon: float = 0.01, timeout: int = 3000000000, + grpc_metadata: Optional[Any] = None, **kwargs: Any, ) -> List[Document]: docs_and_scores = self.similarity_search_with_score( - query, k, radius, epsilon, timeout + query, k, radius, epsilon, timeout, grpc_metadata ) docs = [] @@ -151,11 +177,12 @@ class Vald(VectorStore): radius: float = -1.0, epsilon: float = 0.01, timeout: int = 3000000000, + grpc_metadata: Optional[Any] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: emb = self._embedding.embed_query(query) docs_and_scores = self.similarity_search_with_score_by_vector( - emb, k, radius, epsilon, timeout + emb, k, radius, epsilon, timeout, grpc_metadata ) return docs_and_scores @@ -167,10 +194,11 @@ class Vald(VectorStore): radius: float = -1.0, epsilon: float = 0.01, timeout: int = 3000000000, + grpc_metadata: Optional[Any] = None, **kwargs: Any, ) -> List[Document]: docs_and_scores = self.similarity_search_with_score_by_vector( - embedding, k, radius, epsilon, timeout + embedding, k, radius, epsilon, timeout, grpc_metadata ) docs = [] @@ -186,10 +214,10 @@ class Vald(VectorStore): radius: float = -1.0, epsilon: float = 0.01, timeout: int = 3000000000, + grpc_metadata: Optional[Any] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: try: - import grpc from vald.v1.payload import payload_pb2 from vald.v1.vald import search_pb2_grpc except ImportError: @@ -198,7 +226,7 @@ class Vald(VectorStore): "Please install it with `pip install vald-client-python`." ) - channel = grpc.insecure_channel(self.target, options=self.grpc_options) + channel = self._get_channel() # Depending on the network quality, # it is necessary to wait for ChannelConnectivity.READY. # _ = grpc.channel_ready_future(channel).result(timeout=10) @@ -207,7 +235,10 @@ class Vald(VectorStore): num=k, radius=radius, epsilon=epsilon, timeout=timeout ) - res = stub.Search(payload_pb2.Search.Request(vector=embedding, config=cfg)) + res = stub.Search( + payload_pb2.Search.Request(vector=embedding, config=cfg), + metadata=grpc_metadata, + ) docs_and_scores = [] for result in res.results: @@ -225,6 +256,7 @@ class Vald(VectorStore): radius: float = -1.0, epsilon: float = 0.01, timeout: int = 3000000000, + grpc_metadata: Optional[Any] = None, **kwargs: Any, ) -> List[Document]: emb = self._embedding.embed_query(query) @@ -236,6 +268,7 @@ class Vald(VectorStore): epsilon=epsilon, timeout=timeout, lambda_mult=lambda_mult, + grpc_metadata=grpc_metadata, ) return docs @@ -249,10 +282,10 @@ class Vald(VectorStore): radius: float = -1.0, epsilon: float = 0.01, timeout: int = 3000000000, + grpc_metadata: Optional[Any] = None, **kwargs: Any, ) -> List[Document]: try: - import grpc from vald.v1.payload import payload_pb2 from vald.v1.vald import object_pb2_grpc except ImportError: @@ -260,15 +293,19 @@ class Vald(VectorStore): "Could not import vald-client-python python package. " "Please install it with `pip install vald-client-python`." ) - - channel = grpc.insecure_channel(self.target, options=self.grpc_options) + channel = self._get_channel() # Depending on the network quality, # it is necessary to wait for ChannelConnectivity.READY. # _ = grpc.channel_ready_future(channel).result(timeout=10) stub = object_pb2_grpc.ObjectStub(channel) docs_and_scores = self.similarity_search_with_score_by_vector( - embedding, fetch_k=fetch_k, radius=radius, epsilon=epsilon, timeout=timeout + embedding, + fetch_k=fetch_k, + radius=radius, + epsilon=epsilon, + timeout=timeout, + grpc_metadata=grpc_metadata, ) docs = [] @@ -277,7 +314,8 @@ class Vald(VectorStore): vec = stub.GetObject( payload_pb2.Object.VectorRequest( id=payload_pb2.Object.ID(id=doc.page_content) - ) + ), + metadata=grpc_metadata, ) embs.append(vec.vector) docs.append(doc) @@ -304,6 +342,9 @@ class Vald(VectorStore): ("grpc.keepalive_time_ms", 1000 * 10), ("grpc.keepalive_timeout_ms", 1000 * 10), ), + grpc_use_secure: bool = False, + grpc_credentials: Optional[Any] = None, + grpc_metadata: Optional[Any] = None, skip_strict_exist_check: bool = False, **kwargs: Any, ) -> Vald: @@ -316,11 +357,14 @@ class Vald(VectorStore): host=host, port=port, grpc_options=grpc_options, + grpc_use_secure=grpc_use_secure, + grpc_credentials=grpc_credentials, **kwargs, ) vald.add_texts( texts=texts, metadatas=metadatas, + grpc_metadata=grpc_metadata, skip_strict_exist_check=skip_strict_exist_check, ) return vald