Improvements to the Clarifai integration (#9290)

- Improved docs
- Improved performance in multiple ways through batching, threading,
etc.
 - fixed error message 
 - Added support for metadata filtering during similarity search.

@baskaryan PTAL
This commit is contained in:
Matthew Zeiler 2023-08-21 15:53:36 -04:00 committed by GitHub
parent 66a47d9a61
commit 949b2cf177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 261 additions and 100 deletions

View File

@ -37,7 +37,7 @@ There is a Clarifai Embedding model in LangChain, which you can access with:
from langchain.embeddings import ClarifaiEmbeddings from langchain.embeddings import ClarifaiEmbeddings
embeddings = ClarifaiEmbeddings(pat=CLARIFAI_PAT, user_id=USER_ID, app_id=APP_ID, model_id=MODEL_ID) embeddings = ClarifaiEmbeddings(pat=CLARIFAI_PAT, user_id=USER_ID, app_id=APP_ID, model_id=MODEL_ID)
``` ```
For more details, the docs on the Clarifai Embeddings wrapper provide a [detailed walthrough](/docs/integrations/text_embedding/clarifai.html). For more details, the docs on the Clarifai Embeddings wrapper provide a [detailed walkthrough](/docs/integrations/text_embedding/clarifai.html).
## Vectorstore ## Vectorstore
@ -49,4 +49,4 @@ You an also add data directly from LangChain as well, and the auto-indexing will
from langchain.vectorstores import Clarifai from langchain.vectorstores import Clarifai
clarifai_vector_db = Clarifai.from_texts(user_id=USER_ID, app_id=APP_ID, texts=texts, pat=CLARIFAI_PAT, number_of_docs=NUMBER_OF_DOCS, metadatas = metadatas) clarifai_vector_db = Clarifai.from_texts(user_id=USER_ID, app_id=APP_ID, texts=texts, pat=CLARIFAI_PAT, number_of_docs=NUMBER_OF_DOCS, metadatas = metadatas)
``` ```
For more details, the docs on the Clarifai vector store provide a [detailed walthrough](/docs/integrations/text_embedding/clarifai.html). For more details, the docs on the Clarifai vector store provide a [detailed walkthrough](/docs/integrations/text_embedding/clarifai.html).

View File

@ -130,9 +130,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"USER_ID = \"openai\"\n", "USER_ID = \"salesforce\"\n",
"APP_ID = \"embed\"\n", "APP_ID = \"blip\"\n",
"MODEL_ID = \"text-embedding-ada\"\n", "MODEL_ID = \"multimodal-embedder-blip-2\"\n",
"\n", "\n",
"# You can provide a specific model version as the model_version_id arg.\n", "# You can provide a specific model version as the model_version_id arg.\n",
"# MODEL_VERSION_ID = \"MODEL_VERSION_ID\"" "# MODEL_VERSION_ID = \"MODEL_VERSION_ID\""

View File

@ -53,7 +53,15 @@
"execution_count": 1, "execution_count": 1,
"id": "c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87", "id": "c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"source": [ "source": [
"# Please login and get your API key from https://clarifai.com/settings/security\n", "# Please login and get your API key from https://clarifai.com/settings/security\n",
"from getpass import getpass\n", "from getpass import getpass\n",
@ -61,18 +69,9 @@
"CLARIFAI_PAT = getpass()" "CLARIFAI_PAT = getpass()"
] ]
}, },
{
"attachments": {},
"cell_type": "markdown",
"id": "320af802-9271-46ee-948f-d2453933d44b",
"metadata": {},
"source": [
"We want to use `OpenAIEmbeddings` so we have to get the OpenAI API Key."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 6,
"id": "aac9563e", "id": "aac9563e",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -99,7 +98,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "4d853395", "id": "4d853395",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -134,7 +133,7 @@
" \"I love playing soccer with my friends\",\n", " \"I love playing soccer with my friends\",\n",
"]\n", "]\n",
"\n", "\n",
"metadatas = [{\"id\": i, \"text\": text} for i, text in enumerate(texts)]" "metadatas = [{\"id\": i, \"text\": text, \"source\": \"book 1\", \"category\": [\"books\", \"modern\"]} for i, text in enumerate(texts)]"
] ]
}, },
{ {
@ -156,21 +155,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": null,
"id": "e755cdce", "id": "e755cdce",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"[Document(page_content='I really enjoy spending time with you', metadata={'text': 'I really enjoy spending time with you', 'id': 0.0}),\n", "[Document(page_content='I really enjoy spending time with you', metadata={'text': 'I really enjoy spending time with you', 'id': 0.0, 'source': 'book 1', 'category': ['books', 'modern']}),\n",
" Document(page_content='I went to the movies yesterday', metadata={'text': 'I went to the movies yesterday', 'id': 3.0}),\n", " Document(page_content='I went to the movies yesterday', metadata={'text': 'I went to the movies yesterday', 'id': 3.0, 'source': 'book 1', 'category': ['books', 'modern']})]"
" Document(page_content='zab', metadata={'page': '2'}),\n",
" Document(page_content='zab', metadata={'page': '2'})]"
] ]
}, },
"execution_count": 7,
"metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
@ -179,6 +174,21 @@
"docs" "docs"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "140103ec-0936-454a-9f4a-7d5beefc138f",
"metadata": {},
"outputs": [],
"source": [
"# There is lots of powerful filtering you can do within an app by leveraging metadata filters. \n",
"# This one will limit the similarity query to only the texts that have key of \"source\" matching value of \"book 1\"\n",
"book1_similar_docs = clarifai_vector_db.similarity_search(\"I would love to see you\", filter={\"source\": \"book 1\"})\n",
"\n",
"# You can also use lists in the input's metadata and then select things that match an item in the list. This is useful for categories, as shown below:\n",
"book_category_similar_docs = clarifai_vector_db.similarity_search(\"I would love to see you\", filter={\"category\": [\"books\"]})"
]
},
{ {
"attachments": {}, "attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
@ -249,7 +259,7 @@
" user_id=USER_ID,\n", " user_id=USER_ID,\n",
" app_id=APP_ID,\n", " app_id=APP_ID,\n",
" documents=docs,\n", " documents=docs,\n",
" pat=CLARIFAI_PAT_KEY,\n", " pat=CLARIFAI_PAT,\n",
" number_of_docs=NUMBER_OF_DOCS,\n", " number_of_docs=NUMBER_OF_DOCS,\n",
")" ")"
] ]
@ -278,6 +288,55 @@
"docs = clarifai_vector_db.similarity_search(\"Texts related to criminals and violence\")\n", "docs = clarifai_vector_db.similarity_search(\"Texts related to criminals and violence\")\n",
"docs" "docs"
] ]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7b332ca4-416b-4ea6-99da-b6949f399d72",
"metadata": {},
"source": [
"## From existing App\n",
"Within Clarifai we have great tools for adding data to applications (essentially projects) via API or UI. Most users will already have done that before interacting with LangChain so this example will use the data in an existing app to perform searches. Check out our [API docs](https://docs.clarifai.com/api-guide/data/create-get-update-delete) and [UI docs](https://docs.clarifai.com/portal-guide/data). The Clarifai Application can then be used for semantic search to find relevant documents."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "807c1141-591b-436d-abaa-f2c325e66d39",
"metadata": {},
"outputs": [],
"source": [
"USER_ID = \"USERNAME_ID\"\n",
"APP_ID = \"APPLICATION_ID\"\n",
"NUMBER_OF_DOCS = 4"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "762d74ef-f7df-43d6-b121-4980c4059fc0",
"metadata": {},
"outputs": [],
"source": [
"clarifai_vector_db = Clarifai(\n",
" user_id=USER_ID,\n",
" app_id=APP_ID,\n",
" documents=docs,\n",
" pat=CLARIFAI_PAT,\n",
" number_of_docs=NUMBER_OF_DOCS,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7636b0f-68ab-4b8f-ba0f-3c27061e3631",
"metadata": {},
"outputs": [],
"source": [
"docs = clarifai_vector_db.similarity_search(\"Texts related to criminals and violence\")\n",
"docs"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -103,37 +103,44 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install clarifai`." "Please install it with `pip install clarifai`."
) )
post_model_outputs_request = service_pb2.PostModelOutputsRequest( batch_size = 32
user_app_id=self.userDataObject, embeddings = []
model_id=self.model_id, for i in range(0, len(texts), batch_size):
version_id=self.model_version_id, batch = texts[i : i + batch_size]
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
)
for t in texts
],
)
post_model_outputs_response = self.stub.PostModelOutputs(
post_model_outputs_request
)
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS: post_model_outputs_request = service_pb2.PostModelOutputsRequest(
logger.error(post_model_outputs_response.status) user_app_id=self.userDataObject,
first_output_failure = ( model_id=self.model_id,
post_model_outputs_response.outputs[0].status version_id=self.model_version_id,
if len(post_model_outputs_response.outputs[0]) inputs=[
else None resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
)
for t in batch
],
) )
raise Exception( post_model_outputs_response = self.stub.PostModelOutputs(
f"Post model outputs failed, status: " post_model_outputs_request
f"{post_model_outputs_response.status}, first output failure: " )
f"{first_output_failure}"
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
logger.error(post_model_outputs_response.status)
first_output_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs)
else None
)
raise Exception(
f"Post model outputs failed, status: "
f"{post_model_outputs_response.status}, first output failure: "
f"{first_output_failure}"
)
embeddings.extend(
[
list(o.data.embeddings[0].vector)
for o in post_model_outputs_response.outputs
]
) )
embeddings = [
list(o.data.embeddings[0].vector)
for o in post_model_outputs_response.outputs
]
return embeddings return embeddings
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:

View File

@ -5,6 +5,7 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -163,7 +164,7 @@ class Clarifai(LLM):
logger.error(post_model_outputs_response.status) logger.error(post_model_outputs_response.status)
first_model_failure = ( first_model_failure = (
post_model_outputs_response.outputs[0].status post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs[0]) if len(post_model_outputs_response.outputs)
else None else None
) )
raise Exception( raise Exception(
@ -178,3 +179,67 @@ class Clarifai(LLM):
if stop is not None: if stop is not None:
text = enforce_stop_tokens(text, stop) text = enforce_stop_tokens(text, stop)
return text return text
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
try:
from clarifai_grpc.grpc.api import (
resources_pb2,
service_pb2,
)
from clarifai_grpc.grpc.api.status import status_code_pb2
except ImportError:
raise ImportError(
"Could not import clarifai python package. "
"Please install it with `pip install clarifai`."
)
# TODO: add caching here.
generations = []
batch_size = 32
for i in range(0, len(prompts), batch_size):
batch = prompts[i : i + batch_size]
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
user_app_id=self.userDataObject,
model_id=self.model_id,
version_id=self.model_version_id,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=prompt))
)
for prompt in batch
],
)
post_model_outputs_response = self.stub.PostModelOutputs(
post_model_outputs_request
)
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
logger.error(post_model_outputs_response.status)
first_model_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs)
else None
)
raise Exception(
f"Post model outputs failed, status: "
f"{post_model_outputs_response.status}, first output failure: "
f"{first_model_failure}"
)
for output in post_model_outputs_response.outputs:
if stop is not None:
text = enforce_stop_tokens(output.data.text.raw, stop)
else:
text = output.data.text.raw
generations.append([Generation(text=text)])
return LLMResult(generations=generations)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging import logging
import os import os
import traceback import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
import requests import requests
@ -84,7 +85,9 @@ class Clarifai(VectorStore):
self._userDataObject = self._auth.get_user_app_id_proto() self._userDataObject = self._auth.get_user_app_id_proto()
self._number_of_docs = number_of_docs self._number_of_docs = number_of_docs
def _post_text_input(self, text: str, metadata: dict) -> str: def _post_texts_as_inputs(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[str]:
"""Post text to Clarifai and return the ID of the input. """Post text to Clarifai and return the ID of the input.
Args: Args:
@ -104,20 +107,29 @@ class Clarifai(VectorStore):
"Please install it with `pip install clarifai`." "Please install it with `pip install clarifai`."
) from e ) from e
input_metadata = Struct() if metadatas is not None:
input_metadata.update(metadata) assert len(list(texts)) == len(
metadatas
), "Number of texts and metadatas should be the same."
inputs = []
for idx, text in enumerate(texts):
if metadatas is not None:
input_metadata = Struct()
input_metadata.update(metadatas[idx])
inputs.append(
resources_pb2.Input(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=text),
metadata=input_metadata,
)
)
)
post_inputs_response = self._stub.PostInputs( post_inputs_response = self._stub.PostInputs(
service_pb2.PostInputsRequest( service_pb2.PostInputsRequest(
user_app_id=self._userDataObject, user_app_id=self._userDataObject,
inputs=[ inputs=inputs,
resources_pb2.Input(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=text),
metadata=input_metadata,
)
)
],
) )
) )
@ -127,9 +139,11 @@ class Clarifai(VectorStore):
"Post inputs failed, status: " + post_inputs_response.status.description "Post inputs failed, status: " + post_inputs_response.status.description
) )
input_id = post_inputs_response.inputs[0].id input_ids = []
for input in post_inputs_response.inputs:
input_ids.append(input.id)
return input_id return input_ids
def add_texts( def add_texts(
self, self,
@ -140,7 +154,7 @@ class Clarifai(VectorStore):
) -> List[str]: ) -> List[str]:
"""Add texts to the Clarifai vectorstore. This will push the text """Add texts to the Clarifai vectorstore. This will push the text
to a Clarifai application. to a Clarifai application.
Application use base workflow that create and store embedding for each text. Application use a base workflow that create and store embedding for each text.
Make sure you are using a base workflow that is compatible with text Make sure you are using a base workflow that is compatible with text
(such as Language Understanding). (such as Language Understanding).
@ -153,20 +167,26 @@ class Clarifai(VectorStore):
List[str]: List of IDs of the added texts. List[str]: List of IDs of the added texts.
""" """
assert len(list(texts)) > 0, "No texts provided to add to the vectorstore." ltexts = list(texts)
length = len(ltexts)
assert length > 0, "No texts provided to add to the vectorstore."
if metadatas is not None: if metadatas is not None:
assert len(list(texts)) == len( assert length == len(
metadatas metadatas
), "Number of texts and metadatas should be the same." ), "Number of texts and metadatas should be the same."
batch_size = 32
input_ids = [] input_ids = []
for idx, text in enumerate(texts): for idx in range(0, length, batch_size):
try: try:
metadata = metadatas[idx] if metadatas else {} batch_texts = ltexts[idx : idx + batch_size]
input_id = self._post_text_input(text, metadata) batch_metadatas = (
input_ids.append(input_id) metadatas[idx : idx + batch_size] if metadatas else None
logger.debug(f"Input {input_id} posted successfully.") )
result_ids = self._post_texts_as_inputs(batch_texts, batch_metadatas)
input_ids.extend(result_ids)
logger.debug(f"Input {result_ids} posted successfully.")
except Exception as error: except Exception as error:
logger.warning(f"Post inputs failed: {error}") logger.warning(f"Post inputs failed: {error}")
traceback.print_exc() traceback.print_exc()
@ -196,6 +216,7 @@ class Clarifai(VectorStore):
from clarifai_grpc.grpc.api import resources_pb2, service_pb2 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2 from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf import json_format # type: ignore from google.protobuf import json_format # type: ignore
from google.protobuf.struct_pb2 import Struct # type: ignore
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Could not import clarifai python package. " "Could not import clarifai python package. "
@ -206,28 +227,35 @@ class Clarifai(VectorStore):
if self._number_of_docs is not None: if self._number_of_docs is not None:
k = self._number_of_docs k = self._number_of_docs
post_annotations_searches_response = self._stub.PostAnnotationsSearches( req = service_pb2.PostAnnotationsSearchesRequest(
service_pb2.PostAnnotationsSearchesRequest( user_app_id=self._userDataObject,
user_app_id=self._userDataObject, searches=[
searches=[ resources_pb2.Search(
resources_pb2.Search( query=resources_pb2.Query(
query=resources_pb2.Query( ranks=[
ranks=[ resources_pb2.Rank(
resources_pb2.Rank( annotation=resources_pb2.Annotation(
annotation=resources_pb2.Annotation( data=resources_pb2.Data(
data=resources_pb2.Data( text=resources_pb2.Text(raw=query),
text=resources_pb2.Text(raw=query),
)
) )
) )
] )
) ]
) )
], )
pagination=service_pb2.Pagination(page=1, per_page=k), ],
) pagination=service_pb2.Pagination(page=1, per_page=k),
) )
# Add filter by metadata if provided.
if filter is not None:
search_metadata = Struct()
search_metadata.update(filter)
f = req.searches[0].query.filters.add()
f.annotation.data.metadata.update(search_metadata)
post_annotations_searches_response = self._stub.PostAnnotationsSearches(req)
# Check if search was successful # Check if search was successful
if post_annotations_searches_response.status.code != status_code_pb2.SUCCESS: if post_annotations_searches_response.status.code != status_code_pb2.SUCCESS:
raise Exception( raise Exception(
@ -238,11 +266,12 @@ class Clarifai(VectorStore):
# Retrieve hits # Retrieve hits
hits = post_annotations_searches_response.hits hits = post_annotations_searches_response.hits
docs_and_scores = [] executor = ThreadPoolExecutor(max_workers=10)
# Iterate over hits and retrieve metadata and text
for hit in hits: def hit_to_document(hit: resources_pb2.Hit) -> Tuple[Document, float]:
metadata = json_format.MessageToDict(hit.input.data.metadata) metadata = json_format.MessageToDict(hit.input.data.metadata)
request = requests.get(hit.input.data.text.url) h = {"Authorization": f"Key {self._auth.pat}"}
request = requests.get(hit.input.data.text.url, headers=h)
# override encoding by real educated guess as provided by chardet # override encoding by real educated guess as provided by chardet
request.encoding = request.apparent_encoding request.encoding = request.apparent_encoding
@ -252,10 +281,11 @@ class Clarifai(VectorStore):
f"\tScore {hit.score:.2f} for annotation: {hit.annotation.id}\ f"\tScore {hit.score:.2f} for annotation: {hit.annotation.id}\
off input: {hit.input.id}, text: {requested_text[:125]}" off input: {hit.input.id}, text: {requested_text[:125]}"
) )
return (Document(page_content=requested_text, metadata=metadata), hit.score)
docs_and_scores.append( # Iterate over hits and retrieve metadata and text
(Document(page_content=requested_text, metadata=metadata), hit.score) futures = [executor.submit(hit_to_document, hit) for hit in hits]
) docs_and_scores = [future.result() for future in futures]
return docs_and_scores return docs_and_scores