mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
Add matching engine vectorstore (#3350)
Co-authored-by: Tom Piaggio <tomaspiaggio@google.com> Co-authored-by: scafati98 <jupyter@matchingengine.us-central1-a.c.scafati-joonix.internal> Co-authored-by: scafati98 <scafatieugenio@gmail.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
8bcaca435a
commit
470b2822a3
346
docs/modules/indexes/vectorstores/examples/matchingengine.ipynb
Normal file
346
docs/modules/indexes/vectorstores/examples/matchingengine.ipynb
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "655b8f55-2089-4733-8b09-35dea9580695",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# MatchingEngine\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use functionality related to the GCP Vertex AI `MatchingEngine` vector database.\n",
|
||||||
|
"\n",
|
||||||
|
"> Vertex AI [Matching Engine](https://cloud.google.com/vertex-ai/docs/matching-engine/overview) provides the industry's leading high-scale low latency vector database. These vector databases are commonly referred to as vector similarity-matching or an approximate nearest neighbor (ANN) service.\n",
|
||||||
|
"\n",
|
||||||
|
"**Note**: This module expects an endpoint and deployed index already created as the creation time takes close to one hour. To see how to create an index refer to the section [Create Index and deploy it to an Endpoint](#create-index-and-deploy-it-to-an-endpoint)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a9971578-0ae9-4809-9e80-e5f9d3dcc98a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create VectorStore from texts"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f7c96da4-8d97-4f69-8c13-d2fcafc03b05",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.vectorstores import MatchingEngine"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "58b70880-edd9-46f3-b769-f26c2bcc8395",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"texts = ['The cat sat on', 'the mat.', 'I like to', 'eat pizza for', 'dinner.', 'The sun sets', 'in the west.']\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"vector_store = MatchingEngine.from_components(\n",
|
||||||
|
" texts=texts,\n",
|
||||||
|
" project_id=\"<my_project_id>\",\n",
|
||||||
|
" region=\"<my_region>\",\n",
|
||||||
|
" gcs_bucket_uri=\"<my_gcs_bucket>\",\n",
|
||||||
|
" index_id=\"<my_matching_engine_index_id>\",\n",
|
||||||
|
" endpoint_id=\"<my_matching_engine_endpoint_id>\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"vector_store.add_texts(texts=texts)\n",
|
||||||
|
"\n",
|
||||||
|
"vector_store.similarity_search(\"lunch\", k=2)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0e76e05c-d4ef-49a1-b1b9-2ea989a0eda3",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Create Index and deploy it to an Endpoint"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "61935a91-5efb-48af-bb40-ea1e83e24974",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Imports, Constants and Configs"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "421b66c9-5b8f-4ef7-821e-12886a62b672",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Installing dependencies.\n",
|
||||||
|
"!pip install tensorflow \\\n",
|
||||||
|
" google-cloud-aiplatform \\\n",
|
||||||
|
" tensorflow-hub \\\n",
|
||||||
|
" tensorflow-text "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e4e9cc02-371e-40a1-bce9-37ac8efdf2cb",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import json\n",
|
||||||
|
"\n",
|
||||||
|
"from google.cloud import aiplatform\n",
|
||||||
|
"import tensorflow_hub as hub\n",
|
||||||
|
"import tensorflow_text"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "352a05df-6532-4aba-a36f-603327a5bc5b",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"PROJECT_ID = \"<my_project_id>\"\n",
|
||||||
|
"REGION = \"<my_region>\"\n",
|
||||||
|
"VPC_NETWORK = \"<my_vpc_network_name>\"\n",
|
||||||
|
"PEERING_RANGE_NAME = \"ann-langchain-me-range\" # Name for creating the VPC peering.\n",
|
||||||
|
"BUCKET_URI = \"gs://<bucket_uri>\"\n",
|
||||||
|
"# The number of dimensions for the tensorflow universal sentence encoder. \n",
|
||||||
|
"# If another embedder is used, the dimensions would probably need to change.\n",
|
||||||
|
"DIMENSIONS = 512\n",
|
||||||
|
"DISPLAY_NAME = \"index-test-name\"\n",
|
||||||
|
"EMBEDDING_DIR = f\"{BUCKET_URI}/banana\"\n",
|
||||||
|
"DEPLOYED_INDEX_ID = \"endpoint-test-name\"\n",
|
||||||
|
"\n",
|
||||||
|
"PROJECT_NUMBER = !gcloud projects list --filter=\"PROJECT_ID:'{PROJECT_ID}'\" --format='value(PROJECT_NUMBER)'\n",
|
||||||
|
"PROJECT_NUMBER = PROJECT_NUMBER[0]\n",
|
||||||
|
"VPC_NETWORK_FULL = f\"projects/{PROJECT_NUMBER}/global/networks/{VPC_NETWORK}\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Change this if you need the VPC to be created.\n",
|
||||||
|
"CREATE_VPC = False"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "076e7931-f83e-4597-8748-c8004fd8de96",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Set the project id\n",
|
||||||
|
"! gcloud config set project {PROJECT_ID}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4265081b-a5b7-491e-8ac5-1e26975b9974",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Remove the if condition to run the encapsulated code\n",
|
||||||
|
"if CREATE_VPC:\n",
|
||||||
|
" # Create a VPC network\n",
|
||||||
|
" ! gcloud compute networks create {VPC_NETWORK} --bgp-routing-mode=regional --subnet-mode=auto --project={PROJECT_ID}\n",
|
||||||
|
"\n",
|
||||||
|
" # Add necessary firewall rules\n",
|
||||||
|
" ! gcloud compute firewall-rules create {VPC_NETWORK}-allow-icmp --network {VPC_NETWORK} --priority 65534 --project {PROJECT_ID} --allow icmp\n",
|
||||||
|
"\n",
|
||||||
|
" ! gcloud compute firewall-rules create {VPC_NETWORK}-allow-internal --network {VPC_NETWORK} --priority 65534 --project {PROJECT_ID} --allow all --source-ranges 10.128.0.0/9\n",
|
||||||
|
"\n",
|
||||||
|
" ! gcloud compute firewall-rules create {VPC_NETWORK}-allow-rdp --network {VPC_NETWORK} --priority 65534 --project {PROJECT_ID} --allow tcp:3389\n",
|
||||||
|
"\n",
|
||||||
|
" ! gcloud compute firewall-rules create {VPC_NETWORK}-allow-ssh --network {VPC_NETWORK} --priority 65534 --project {PROJECT_ID} --allow tcp:22\n",
|
||||||
|
"\n",
|
||||||
|
" # Reserve IP range\n",
|
||||||
|
" ! gcloud compute addresses create {PEERING_RANGE_NAME} --global --prefix-length=16 --network={VPC_NETWORK} --purpose=VPC_PEERING --project={PROJECT_ID} --description=\"peering range\"\n",
|
||||||
|
"\n",
|
||||||
|
" # Set up peering with service networking\n",
|
||||||
|
" # Your account must have the \"Compute Network Admin\" role to run the following.\n",
|
||||||
|
" ! gcloud services vpc-peerings connect --service=servicenetworking.googleapis.com --network={VPC_NETWORK} --ranges={PEERING_RANGE_NAME} --project={PROJECT_ID}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9dfbb847-fc53-48c1-b0f2-00d1c4330b01",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Creating bucket.\n",
|
||||||
|
"! gsutil mb -l $REGION -p $PROJECT_ID $BUCKET_URI"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "f9698068-3d2f-471b-90c3-dae3e4ca6f63",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Using Tensorflow Universal Sentence Encoder as an Embedder"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "144007e2-ddf8-43cd-ac45-848be0458ba9",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load the Universal Sentence Encoder module\n",
|
||||||
|
"module_url = \"https://tfhub.dev/google/universal-sentence-encoder-multilingual/3\"\n",
|
||||||
|
"model = hub.load(module_url)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "94a2bdcb-c7e3-4fb0-8c97-cc1f2263f06c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Generate embeddings for each word\n",
|
||||||
|
"embeddings = model(['banana'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5a4e6e99-5e42-4e55-90f6-c03aae4fbf14",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Inserting a test embedding"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "024c78f3-4663-4d8f-9f3c-b7d82073ada4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"initial_config = {\"id\": \"banana_id\", \"embedding\": [float(x) for x in list(embeddings.numpy()[0])]}\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"data.json\", \"w\") as f:\n",
|
||||||
|
" json.dump(initial_config, f)\n",
|
||||||
|
"\n",
|
||||||
|
"!gsutil cp data.json {EMBEDDING_DIR}/file.json"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a11489f4-5904-4fc2-9178-f32c2df0406d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e3c6953b-11f6-4803-bf2d-36fa42abf3c7",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Creating Index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c31c3c56-bfe0-49ec-9901-cd146f592da7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(\n",
|
||||||
|
" display_name=DISPLAY_NAME,\n",
|
||||||
|
" contents_delta_uri=EMBEDDING_DIR,\n",
|
||||||
|
" dimensions=DIMENSIONS,\n",
|
||||||
|
" approximate_neighbors_count=150,\n",
|
||||||
|
" distance_measure_type=\"DOT_PRODUCT_DISTANCE\"\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "50770669-edf6-4796-9563-d1ea59cfa8e8",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Creating Endpoint"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "20c93d1b-a7d5-47b0-9c95-1aec1c62e281",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(\n",
|
||||||
|
" display_name=f\"{DISPLAY_NAME}-endpoint\",\n",
|
||||||
|
" network=VPC_NETWORK_FULL,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b52df797-28db-4b4a-b79c-e8a274293a6a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Deploy Index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "019a7043-ad11-4a48-bec7-18928547b2ba",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"my_index_endpoint = my_index_endpoint.deploy_index(\n",
|
||||||
|
" index=my_index, \n",
|
||||||
|
" deployed_index_id=DEPLOYED_INDEX_ID\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"my_index_endpoint.deployed_indexes"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"environment": {
|
||||||
|
"kernel": "python3",
|
||||||
|
"name": "common-cpu.m107",
|
||||||
|
"type": "gcloud",
|
||||||
|
"uri": "gcr.io/deeplearning-platform-release/base-cpu:m107"
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
441
langchain/vectorstores/matching_engine.py
Normal file
441
langchain/vectorstores/matching_engine.py
Normal file
@ -0,0 +1,441 @@
|
|||||||
|
"""Vertex Matching Engine implementation of the vector store."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.embeddings import TensorflowHubEmbeddings
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from google.cloud import storage
|
||||||
|
from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint
|
||||||
|
from google.oauth2.service_account import Credentials
|
||||||
|
|
||||||
|
# Use the module-level logger (not the root logger) so callers can
# configure/filter this vector store's log output independently.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MatchingEngine(VectorStore):
    """Vertex Matching Engine implementation of the vector store.

    While the embeddings are stored in the Matching Engine, the embedded
    documents will be stored in GCS.

    An existing Index and corresponding Endpoint are preconditions for
    using this module.

    See usage in docs/modules/indexes/vectorstores/examples/matchingengine.ipynb

    Note that this implementation is mostly meant for reading if you are
    planning to do a real time implementation. While reading is a real time
    operation, updating the index takes close to one hour."""

    def __init__(
        self,
        project_id: str,
        index: MatchingEngineIndex,
        endpoint: MatchingEngineIndexEndpoint,
        embedding: Embeddings,
        gcs_client: storage.Client,
        gcs_bucket_name: str,
        credentials: Optional[Credentials] = None,
    ):
        """Vertex Matching Engine implementation of the vector store.

        While the embeddings are stored in the Matching Engine, the embedded
        documents will be stored in GCS.

        An existing Index and corresponding Endpoint are preconditions for
        using this module.

        See usage in
        docs/modules/indexes/vectorstores/examples/matchingengine.ipynb.

        Note that this implementation is mostly meant for reading if you are
        planning to do a real time implementation. While reading is a real time
        operation, updating the index takes close to one hour.

        Attributes:
            project_id: The GCS project id.
            index: The created index class. See
                ~:func:`MatchingEngine.from_components`.
            endpoint: The created endpoint class. See
                ~:func:`MatchingEngine.from_components`.
            embedding: A :class:`Embeddings` that will be used for
                embedding the text sent. If none is sent, then the
                multilingual Tensorflow Universal Sentence Encoder will be used.
            gcs_client: The GCS client.
            gcs_bucket_name: The GCS bucket name.
            credentials (Optional): Created GCP credentials.
        """
        super().__init__()
        # Fail fast if the optional Google Cloud dependencies are missing.
        self._validate_google_libraries_installation()

        self.project_id = project_id
        self.index = index
        self.endpoint = endpoint
        self.embedding = embedding
        self.gcs_client = gcs_client
        self.credentials = credentials
        self.gcs_bucket_name = gcs_bucket_name

    def _validate_google_libraries_installation(self) -> None:
        """Validates that Google libraries that are needed are installed."""
        try:
            from google.cloud import aiplatform, storage  # noqa: F401
            from google.oauth2 import service_account  # noqa: F401
        except ImportError:
            raise ImportError(
                "You must run `pip install --upgrade "
                "google-cloud-aiplatform google-cloud-storage` "
                "to use the MatchingEngine Vectorstore."
            )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
                NOTE(review): currently ignored by this implementation.
            kwargs: vectorstore specific parameters.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        logger.debug("Embedding documents.")
        embeddings = self.embedding.embed_documents(list(texts))
        jsons = []
        ids = []
        # Could be improved with async.
        for embedding, text in zip(embeddings, texts):
            # `doc_id` rather than `id` to avoid shadowing the builtin.
            doc_id = str(uuid.uuid4())
            ids.append(doc_id)
            jsons.append({"id": doc_id, "embedding": embedding})
            self._upload_to_gcs(text, f"documents/{doc_id}")

        logger.debug(f"Uploaded {len(ids)} documents to GCS.")

        # Creating json lines from the embedded documents.
        result_str = "\n".join([json.dumps(x) for x in jsons])

        filename_prefix = f"indexes/{uuid.uuid4()}"
        filename = f"{filename_prefix}/{time.time()}.json"
        self._upload_to_gcs(result_str, filename)
        logger.debug(
            f"Uploaded updated json with embeddings to "
            f"{self.gcs_bucket_name}/{filename}."
        )

        # Matching Engine batch-updates from a GCS folder; this call is slow
        # (index updates take close to one hour -- see class docstring).
        self.index = self.index.update_embeddings(
            contents_delta_uri=f"gs://{self.gcs_bucket_name}/{filename_prefix}/"
        )

        logger.debug("Updated index with new configuration.")

        return ids

    def _upload_to_gcs(self, data: str, gcs_location: str) -> None:
        """Uploads data to gcs_location.

        Args:
            data: The data that will be stored.
            gcs_location: The location where the data will be stored.
        """
        bucket = self.gcs_client.get_bucket(self.gcs_bucket_name)
        blob = bucket.blob(gcs_location)
        blob.upload_from_string(data)

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: The string that will be used to search for similar documents.
            k: The amount of neighbors that will be retrieved.

        Returns:
            A list of k matching documents.
        """

        logger.debug(f"Embedding query {query}.")
        embedding_query = self.embedding.embed_documents([query])

        response = self.endpoint.match(
            deployed_index_id=self._get_index_id(),
            queries=embedding_query,
            num_neighbors=k,
        )

        if len(response) == 0:
            return []

        logger.debug(f"Found {len(response)} matches for the query {query}.")

        results = []

        # I'm only getting the first one because queries receives an array
        # and the similarity_search method only receives one query. This
        # means that the match method will always return an array with only
        # one element.
        for doc in response[0]:
            page_content = self._download_from_gcs(f"documents/{doc.id}")
            results.append(Document(page_content=page_content))

        logger.debug("Downloaded documents for query.")

        return results

    def _get_index_id(self) -> str:
        """Gets the correct index id for the endpoint.

        Returns:
            The index id if found (which should be found) or throws
            ValueError otherwise.
        """
        for index in self.endpoint.deployed_indexes:
            if index.index == self.index.resource_name:
                return index.id

        raise ValueError(
            f"No index with id {self.index.resource_name} "
            f"deployed on endpoint "
            f"{self.endpoint.display_name}."
        )

    def _download_from_gcs(self, gcs_location: str) -> str:
        """Downloads from GCS in text format.

        Args:
            gcs_location: The location where the file is located.

        Returns:
            The string contents of the file.
        """
        bucket = self.gcs_client.get_bucket(self.gcs_bucket_name)
        blob = bucket.blob(gcs_location)
        # download_as_string() returns bytes; decode so the return value
        # matches the annotated str type and Document.page_content is a str.
        return blob.download_as_string().decode("utf-8")

    @classmethod
    def from_texts(
        cls: Type["MatchingEngine"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "MatchingEngine":
        """Use from components instead."""
        raise NotImplementedError(
            "This method is not implemented. Instead, you should initialize the class"
            " with `MatchingEngine.from_components(...)` and then call "
            "`add_texts`"
        )

    @classmethod
    def from_components(
        cls: Type["MatchingEngine"],
        project_id: str,
        region: str,
        gcs_bucket_name: str,
        index_id: str,
        endpoint_id: str,
        credentials_path: Optional[str] = None,
        embedding: Optional[Embeddings] = None,
    ) -> "MatchingEngine":
        """Takes the object creation out of the constructor.

        Args:
            project_id: The GCP project id.
            region: The default location making the API calls. It must have
                the same location as the GCS bucket and must be regional.
            gcs_bucket_name: The location where the vectors will be stored in
                order for the index to be created.
            index_id: The id of the created index.
            endpoint_id: The id of the created endpoint.
            credentials_path: (Optional) The path of the Google credentials on
                the local file system.
            embedding: The :class:`Embeddings` that will be used for
                embedding the texts.

        Returns:
            A configured MatchingEngine with the texts added to the index.
        """
        gcs_bucket_name = cls._validate_gcs_bucket(gcs_bucket_name)
        credentials = cls._create_credentials_from_file(credentials_path)
        index = cls._create_index_by_id(index_id, project_id, region, credentials)
        endpoint = cls._create_endpoint_by_id(
            endpoint_id, project_id, region, credentials
        )

        gcs_client = cls._get_gcs_client(credentials, project_id)
        cls._init_aiplatform(project_id, region, gcs_bucket_name, credentials)

        return cls(
            project_id=project_id,
            index=index,
            endpoint=endpoint,
            embedding=embedding or cls._get_default_embeddings(),
            gcs_client=gcs_client,
            credentials=credentials,
            gcs_bucket_name=gcs_bucket_name,
        )

    @classmethod
    def _validate_gcs_bucket(cls, gcs_bucket_name: str) -> str:
        """Validates the gcs_bucket_name as a bucket name.

        Args:
            gcs_bucket_name: The received bucket uri.

        Returns:
            A valid gcs_bucket_name or throws ValueError if full path is
            provided.
        """
        gcs_bucket_name = gcs_bucket_name.replace("gs://", "")
        if "/" in gcs_bucket_name:
            raise ValueError(
                f"The argument gcs_bucket_name should only be "
                f"the bucket name. Received {gcs_bucket_name}"
            )
        return gcs_bucket_name

    @classmethod
    def _create_credentials_from_file(
        cls, json_credentials_path: Optional[str]
    ) -> Optional[Credentials]:
        """Creates credentials for GCP.

        Args:
            json_credentials_path: The path on the file system where the
                credentials are stored.

        Returns:
            An optional of Credentials or None, in which case the default
            will be used.
        """

        from google.oauth2 import service_account

        credentials = None
        if json_credentials_path is not None:
            credentials = service_account.Credentials.from_service_account_file(
                json_credentials_path
            )

        return credentials

    @classmethod
    def _create_index_by_id(
        cls, index_id: str, project_id: str, region: str, credentials: "Credentials"
    ) -> MatchingEngineIndex:
        """Creates a MatchingEngineIndex object by id.

        Args:
            index_id: The created index id.
            project_id: The project to retrieve index from.
            region: Location to retrieve index from.
            credentials: GCS credentials.

        Returns:
            A configured MatchingEngineIndex.
        """

        from google.cloud import aiplatform

        logger.debug(f"Creating matching engine index with id {index_id}.")
        return aiplatform.MatchingEngineIndex(
            index_name=index_id,
            project=project_id,
            location=region,
            credentials=credentials,
        )

    @classmethod
    def _create_endpoint_by_id(
        cls, endpoint_id: str, project_id: str, region: str, credentials: "Credentials"
    ) -> MatchingEngineIndexEndpoint:
        """Creates a MatchingEngineIndexEndpoint object by id.

        Args:
            endpoint_id: The created endpoint id.
            project_id: The project to retrieve index from.
            region: Location to retrieve index from.
            credentials: GCS credentials.

        Returns:
            A configured MatchingEngineIndexEndpoint.
        """

        from google.cloud import aiplatform

        logger.debug(f"Creating endpoint with id {endpoint_id}.")
        return aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name=endpoint_id,
            project=project_id,
            location=region,
            credentials=credentials,
        )

    @classmethod
    def _get_gcs_client(
        cls, credentials: "Credentials", project_id: str
    ) -> "storage.Client":
        """Lazily creates a GCS client.

        Returns:
            A configured GCS client.
        """

        from google.cloud import storage

        return storage.Client(credentials=credentials, project=project_id)

    @classmethod
    def _init_aiplatform(
        cls,
        project_id: str,
        region: str,
        gcs_bucket_name: str,
        credentials: "Credentials",
    ) -> None:
        """Configures the aiplatform library.

        Args:
            project_id: The GCP project id.
            region: The default location making the API calls. It must have
                the same location as the GCS bucket and must be regional.
            gcs_bucket_name: GCS staging location.
            credentials: The GCS Credentials object.
        """

        from google.cloud import aiplatform

        logger.debug(
            f"Initializing AI Platform for project {project_id} on "
            f"{region} and for {gcs_bucket_name}."
        )
        aiplatform.init(
            project=project_id,
            location=region,
            staging_bucket=gcs_bucket_name,
            credentials=credentials,
        )

    @classmethod
    def _get_default_embeddings(cls) -> TensorflowHubEmbeddings:
        """This function returns the default embedding.

        Returns:
            Default TensorflowHubEmbeddings to use.
        """
        return TensorflowHubEmbeddings()
|
Loading…
Reference in New Issue
Block a user