Improvements to the Clarifai integration (#9290)

- Improved docs
- Improved performance in multiple ways through batching, threading,
etc.
 - Fixed error message
 - Added support for metadata filtering during similarity search.

@baskaryan PTAL
This commit is contained in:
Matthew Zeiler
2023-08-21 15:53:36 -04:00
committed by GitHub
parent 66a47d9a61
commit 949b2cf177
6 changed files with 261 additions and 100 deletions

View File

@@ -103,37 +103,44 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install clarifai`."
)
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
user_app_id=self.userDataObject,
model_id=self.model_id,
version_id=self.model_version_id,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
)
for t in texts
],
)
post_model_outputs_response = self.stub.PostModelOutputs(
post_model_outputs_request
)
batch_size = 32
embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
logger.error(post_model_outputs_response.status)
first_output_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs[0])
else None
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
user_app_id=self.userDataObject,
model_id=self.model_id,
version_id=self.model_version_id,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
)
for t in batch
],
)
raise Exception(
f"Post model outputs failed, status: "
f"{post_model_outputs_response.status}, first output failure: "
f"{first_output_failure}"
post_model_outputs_response = self.stub.PostModelOutputs(
post_model_outputs_request
)
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
logger.error(post_model_outputs_response.status)
first_output_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs)
else None
)
raise Exception(
f"Post model outputs failed, status: "
f"{post_model_outputs_response.status}, first output failure: "
f"{first_output_failure}"
)
embeddings.extend(
[
list(o.data.embeddings[0].vector)
for o in post_model_outputs_response.outputs
]
)
embeddings = [
list(o.data.embeddings[0].vector)
for o in post_model_outputs_response.outputs
]
return embeddings
def embed_query(self, text: str) -> List[float]:

View File

@@ -5,6 +5,7 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
@@ -163,7 +164,7 @@ class Clarifai(LLM):
logger.error(post_model_outputs_response.status)
first_model_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs[0])
if len(post_model_outputs_response.outputs)
else None
)
raise Exception(
@@ -178,3 +179,67 @@ class Clarifai(LLM):
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
def _generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Run the Clarifai model on a list of prompts, sending them in batches.

    Prompts are grouped into batches of 32 and each batch is sent as a
    single ``PostModelOutputs`` request.

    Args:
        prompts: The prompts to send to the model.
        stop: Optional stop sequences. When given, every output text is
            passed through ``enforce_stop_tokens`` before being returned.
        run_manager: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        An ``LLMResult`` whose ``generations`` holds one single-element
        list of ``Generation`` per output returned by the API, appended
        in the order the batches were sent.

    Raises:
        ImportError: If the clarifai gRPC modules cannot be imported.
        Exception: If any batch request returns a non-success status.
    """
    try:
        from clarifai_grpc.grpc.api import (
            resources_pb2,
            service_pb2,
        )
        from clarifai_grpc.grpc.api.status import status_code_pb2
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Could not import clarifai python package. "
            "Please install it with `pip install clarifai`."
        ) from e

    # TODO: add caching here.
    generations = []
    batch_size = 32
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i : i + batch_size]
        post_model_outputs_request = service_pb2.PostModelOutputsRequest(
            user_app_id=self.userDataObject,
            model_id=self.model_id,
            version_id=self.model_version_id,
            inputs=[
                resources_pb2.Input(
                    data=resources_pb2.Data(text=resources_pb2.Text(raw=prompt))
                )
                for prompt in batch
            ],
        )
        post_model_outputs_response = self.stub.PostModelOutputs(
            post_model_outputs_request
        )

        if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
            logger.error(post_model_outputs_response.status)
            # Guard on the outputs list itself: it may be empty on failure.
            first_model_failure = (
                post_model_outputs_response.outputs[0].status
                if len(post_model_outputs_response.outputs)
                else None
            )
            raise Exception(
                f"Post model outputs failed, status: "
                f"{post_model_outputs_response.status}, first output failure: "
                f"{first_model_failure}"
            )

        for output in post_model_outputs_response.outputs:
            text = output.data.text.raw
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            generations.append([Generation(text=text)])

    return LLMResult(generations=generations)

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Iterable, List, Optional, Tuple
import requests
@@ -84,7 +85,9 @@ class Clarifai(VectorStore):
self._userDataObject = self._auth.get_user_app_id_proto()
self._number_of_docs = number_of_docs
def _post_text_input(self, text: str, metadata: dict) -> str:
def _post_texts_as_inputs(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[str]:
"""Post text to Clarifai and return the ID of the input.
Args:
@@ -104,20 +107,29 @@ class Clarifai(VectorStore):
"Please install it with `pip install clarifai`."
) from e
input_metadata = Struct()
input_metadata.update(metadata)
if metadatas is not None:
assert len(list(texts)) == len(
metadatas
), "Number of texts and metadatas should be the same."
inputs = []
for idx, text in enumerate(texts):
if metadatas is not None:
input_metadata = Struct()
input_metadata.update(metadatas[idx])
inputs.append(
resources_pb2.Input(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=text),
metadata=input_metadata,
)
)
)
post_inputs_response = self._stub.PostInputs(
service_pb2.PostInputsRequest(
user_app_id=self._userDataObject,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=text),
metadata=input_metadata,
)
)
],
inputs=inputs,
)
)
@@ -127,9 +139,11 @@ class Clarifai(VectorStore):
"Post inputs failed, status: " + post_inputs_response.status.description
)
input_id = post_inputs_response.inputs[0].id
input_ids = []
for input in post_inputs_response.inputs:
input_ids.append(input.id)
return input_id
return input_ids
def add_texts(
self,
@@ -140,7 +154,7 @@ class Clarifai(VectorStore):
) -> List[str]:
"""Add texts to the Clarifai vectorstore. This will push the text
to a Clarifai application.
Application use base workflow that create and store embedding for each text.
Application use a base workflow that create and store embedding for each text.
Make sure you are using a base workflow that is compatible with text
(such as Language Understanding).
@@ -153,20 +167,26 @@ class Clarifai(VectorStore):
List[str]: List of IDs of the added texts.
"""
assert len(list(texts)) > 0, "No texts provided to add to the vectorstore."
ltexts = list(texts)
length = len(ltexts)
assert length > 0, "No texts provided to add to the vectorstore."
if metadatas is not None:
assert len(list(texts)) == len(
assert length == len(
metadatas
), "Number of texts and metadatas should be the same."
batch_size = 32
input_ids = []
for idx, text in enumerate(texts):
for idx in range(0, length, batch_size):
try:
metadata = metadatas[idx] if metadatas else {}
input_id = self._post_text_input(text, metadata)
input_ids.append(input_id)
logger.debug(f"Input {input_id} posted successfully.")
batch_texts = ltexts[idx : idx + batch_size]
batch_metadatas = (
metadatas[idx : idx + batch_size] if metadatas else None
)
result_ids = self._post_texts_as_inputs(batch_texts, batch_metadatas)
input_ids.extend(result_ids)
logger.debug(f"Input {result_ids} posted successfully.")
except Exception as error:
logger.warning(f"Post inputs failed: {error}")
traceback.print_exc()
@@ -196,6 +216,7 @@ class Clarifai(VectorStore):
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf import json_format # type: ignore
from google.protobuf.struct_pb2 import Struct # type: ignore
except ImportError as e:
raise ImportError(
"Could not import clarifai python package. "
@@ -206,28 +227,35 @@ class Clarifai(VectorStore):
if self._number_of_docs is not None:
k = self._number_of_docs
post_annotations_searches_response = self._stub.PostAnnotationsSearches(
service_pb2.PostAnnotationsSearchesRequest(
user_app_id=self._userDataObject,
searches=[
resources_pb2.Search(
query=resources_pb2.Query(
ranks=[
resources_pb2.Rank(
annotation=resources_pb2.Annotation(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=query),
)
req = service_pb2.PostAnnotationsSearchesRequest(
user_app_id=self._userDataObject,
searches=[
resources_pb2.Search(
query=resources_pb2.Query(
ranks=[
resources_pb2.Rank(
annotation=resources_pb2.Annotation(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=query),
)
)
]
)
)
]
)
],
pagination=service_pb2.Pagination(page=1, per_page=k),
)
)
],
pagination=service_pb2.Pagination(page=1, per_page=k),
)
# Add filter by metadata if provided.
if filter is not None:
search_metadata = Struct()
search_metadata.update(filter)
f = req.searches[0].query.filters.add()
f.annotation.data.metadata.update(search_metadata)
post_annotations_searches_response = self._stub.PostAnnotationsSearches(req)
# Check if search was successful
if post_annotations_searches_response.status.code != status_code_pb2.SUCCESS:
raise Exception(
@@ -238,11 +266,12 @@ class Clarifai(VectorStore):
# Retrieve hits
hits = post_annotations_searches_response.hits
docs_and_scores = []
# Iterate over hits and retrieve metadata and text
for hit in hits:
executor = ThreadPoolExecutor(max_workers=10)
def hit_to_document(hit: resources_pb2.Hit) -> Tuple[Document, float]:
metadata = json_format.MessageToDict(hit.input.data.metadata)
request = requests.get(hit.input.data.text.url)
h = {"Authorization": f"Key {self._auth.pat}"}
request = requests.get(hit.input.data.text.url, headers=h)
# override encoding by real educated guess as provided by chardet
request.encoding = request.apparent_encoding
@@ -252,10 +281,11 @@ class Clarifai(VectorStore):
f"\tScore {hit.score:.2f} for annotation: {hit.annotation.id}\
off input: {hit.input.id}, text: {requested_text[:125]}"
)
return (Document(page_content=requested_text, metadata=metadata), hit.score)
docs_and_scores.append(
(Document(page_content=requested_text, metadata=metadata), hit.score)
)
# Iterate over hits and retrieve metadata and text
futures = [executor.submit(hit_to_document, hit) for hit in hits]
docs_and_scores = [future.result() for future in futures]
return docs_and_scores