Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-12 12:59:07 +00:00
Improvements to the Clarifai integration (#9290)
- Improved docs
- Improved performance in multiple ways through batching, threading, etc.
- Fixed error message
- Added support for metadata filtering during similarity search.

@baskaryan PTAL
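For context, a minimal usage sketch of the metadata filtering added here. The IDs and PAT below are placeholders, not part of this commit, and it assumes the `clarifai` package is installed:

    from langchain.vectorstores import Clarifai

    # Placeholder credentials/IDs -- substitute your own Clarifai user, app, and PAT.
    clarifai_vector_db = Clarifai(
        user_id="USER_ID",
        app_id="APP_ID",
        pat="CLARIFAI_PAT",
        number_of_docs=4,
    )

    # add_texts now posts inputs in batches of 32 (see the vectorstore diff below).
    clarifai_vector_db.add_texts(
        texts=["I really enjoy spending time with you", "The weather is terrible today"],
        metadatas=[{"source": "chat"}, {"source": "weather"}],
    )

    # New in this PR: restrict similarity search to inputs whose metadata
    # matches the given dict (applied as a filter on PostAnnotationsSearches).
    docs = clarifai_vector_db.similarity_search("good company", filter={"source": "chat"})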
libs/langchain/langchain/embeddings/clarifai.py
@@ -103,37 +103,44 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
                 "Please install it with `pip install clarifai`."
             )
 
-        post_model_outputs_request = service_pb2.PostModelOutputsRequest(
-            user_app_id=self.userDataObject,
-            model_id=self.model_id,
-            version_id=self.model_version_id,
-            inputs=[
-                resources_pb2.Input(
-                    data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
-                )
-                for t in texts
-            ],
-        )
-        post_model_outputs_response = self.stub.PostModelOutputs(
-            post_model_outputs_request
-        )
-
-        if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
-            logger.error(post_model_outputs_response.status)
-            first_output_failure = (
-                post_model_outputs_response.outputs[0].status
-                if len(post_model_outputs_response.outputs[0])
-                else None
-            )
-            raise Exception(
-                f"Post model outputs failed, status: "
-                f"{post_model_outputs_response.status}, first output failure: "
-                f"{first_output_failure}"
-            )
-        embeddings = [
-            list(o.data.embeddings[0].vector)
-            for o in post_model_outputs_response.outputs
-        ]
+        batch_size = 32
+        embeddings = []
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i : i + batch_size]
+
+            post_model_outputs_request = service_pb2.PostModelOutputsRequest(
+                user_app_id=self.userDataObject,
+                model_id=self.model_id,
+                version_id=self.model_version_id,
+                inputs=[
+                    resources_pb2.Input(
+                        data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
+                    )
+                    for t in batch
+                ],
+            )
+            post_model_outputs_response = self.stub.PostModelOutputs(
+                post_model_outputs_request
+            )
+
+            if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
+                logger.error(post_model_outputs_response.status)
+                first_output_failure = (
+                    post_model_outputs_response.outputs[0].status
+                    if len(post_model_outputs_response.outputs)
+                    else None
+                )
+                raise Exception(
+                    f"Post model outputs failed, status: "
+                    f"{post_model_outputs_response.status}, first output failure: "
+                    f"{first_output_failure}"
+                )
+            embeddings.extend(
+                [
+                    list(o.data.embeddings[0].vector)
+                    for o in post_model_outputs_response.outputs
+                ]
+            )
         return embeddings
 
     def embed_query(self, text: str) -> List[float]:
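For illustration only (not part of the diff): with this change, embed_documents issues one PostModelOutputs call per 32-text batch rather than a single call for the whole list. A hedged sketch, assuming placeholder IDs and a configured PAT:

    from langchain.embeddings import ClarifaiEmbeddings

    # Placeholder IDs -- substitute a real Clarifai text-embedding model.
    embeddings = ClarifaiEmbeddings(
        user_id="USER_ID", app_id="APP_ID", model_id="MODEL_ID", pat="CLARIFAI_PAT"
    )

    # 100 texts -> ceil(100 / 32) = 4 batched requests under the new code path.
    vectors = embeddings.embed_documents([f"document {n}" for n in range(100)])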
libs/langchain/langchain/llms/clarifai.py
@@ -5,6 +5,7 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import Extra, root_validator
+from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
@@ -163,7 +164,7 @@ class Clarifai(LLM):
             logger.error(post_model_outputs_response.status)
             first_model_failure = (
                 post_model_outputs_response.outputs[0].status
-                if len(post_model_outputs_response.outputs[0])
+                if len(post_model_outputs_response.outputs)
                 else None
             )
             raise Exception(
@@ -178,3 +179,67 @@ class Clarifai(LLM):
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
         return text
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+
+        try:
+            from clarifai_grpc.grpc.api import (
+                resources_pb2,
+                service_pb2,
+            )
+            from clarifai_grpc.grpc.api.status import status_code_pb2
+        except ImportError:
+            raise ImportError(
+                "Could not import clarifai python package. "
+                "Please install it with `pip install clarifai`."
+            )
+
+        # TODO: add caching here.
+        generations = []
+        batch_size = 32
+        for i in range(0, len(prompts), batch_size):
+            batch = prompts[i : i + batch_size]
+            post_model_outputs_request = service_pb2.PostModelOutputsRequest(
+                user_app_id=self.userDataObject,
+                model_id=self.model_id,
+                version_id=self.model_version_id,
+                inputs=[
+                    resources_pb2.Input(
+                        data=resources_pb2.Data(text=resources_pb2.Text(raw=prompt))
+                    )
+                    for prompt in batch
+                ],
+            )
+            post_model_outputs_response = self.stub.PostModelOutputs(
+                post_model_outputs_request
+            )
+
+            if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
+                logger.error(post_model_outputs_response.status)
+                first_model_failure = (
+                    post_model_outputs_response.outputs[0].status
+                    if len(post_model_outputs_response.outputs)
+                    else None
+                )
+                raise Exception(
+                    f"Post model outputs failed, status: "
+                    f"{post_model_outputs_response.status}, first output failure: "
+                    f"{first_model_failure}"
+                )
+
+            for output in post_model_outputs_response.outputs:
+                if stop is not None:
+                    text = enforce_stop_tokens(output.data.text.raw, stop)
+                else:
+                    text = output.data.text.raw
+
+                generations.append([Generation(text=text)])
+
+        return LLMResult(generations=generations)
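Again for illustration (placeholder model coordinates, not from the commit): the new _generate lets generate() batch prompts, 32 per gRPC request, returning an LLMResult with one list of Generations per prompt. A rough sketch:

    from langchain.llms import Clarifai

    # Placeholder coordinates -- point these at a real Clarifai text model.
    llm = Clarifai(
        user_id="USER_ID", app_id="APP_ID", model_id="MODEL_ID", pat="CLARIFAI_PAT"
    )

    result = llm.generate(["Say hello.", "Say goodbye."])  # one request for both prompts
    print(result.generations[0][0].text)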
libs/langchain/langchain/vectorstores/clarifai.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 import os
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Iterable, List, Optional, Tuple
 
 import requests
@@ -84,7 +85,9 @@ class Clarifai(VectorStore):
         self._userDataObject = self._auth.get_user_app_id_proto()
         self._number_of_docs = number_of_docs
 
-    def _post_text_input(self, text: str, metadata: dict) -> str:
+    def _post_texts_as_inputs(
+        self, texts: List[str], metadatas: Optional[List[dict]] = None
+    ) -> List[str]:
         """Post text to Clarifai and return the ID of the input.
 
         Args:
@@ -104,20 +107,29 @@ class Clarifai(VectorStore):
                 "Please install it with `pip install clarifai`."
             ) from e
 
-        input_metadata = Struct()
-        input_metadata.update(metadata)
+        if metadatas is not None:
+            assert len(list(texts)) == len(
+                metadatas
+            ), "Number of texts and metadatas should be the same."
+
+        inputs = []
+        for idx, text in enumerate(texts):
+            if metadatas is not None:
+                input_metadata = Struct()
+                input_metadata.update(metadatas[idx])
+            inputs.append(
+                resources_pb2.Input(
+                    data=resources_pb2.Data(
+                        text=resources_pb2.Text(raw=text),
+                        metadata=input_metadata,
+                    )
+                )
+            )
 
         post_inputs_response = self._stub.PostInputs(
             service_pb2.PostInputsRequest(
                 user_app_id=self._userDataObject,
-                inputs=[
-                    resources_pb2.Input(
-                        data=resources_pb2.Data(
-                            text=resources_pb2.Text(raw=text),
-                            metadata=input_metadata,
-                        )
-                    )
-                ],
+                inputs=inputs,
             )
         )
@@ -127,9 +139,11 @@ class Clarifai(VectorStore):
                 "Post inputs failed, status: " + post_inputs_response.status.description
             )
 
-        input_id = post_inputs_response.inputs[0].id
+        input_ids = []
+        for input in post_inputs_response.inputs:
+            input_ids.append(input.id)
 
-        return input_id
+        return input_ids
 
     def add_texts(
         self,
@@ -140,7 +154,7 @@ class Clarifai(VectorStore):
     ) -> List[str]:
         """Add texts to the Clarifai vectorstore. This will push the text
         to a Clarifai application.
-        Application use base workflow that create and store embedding for each text.
+        Application use a base workflow that create and store embedding for each text.
         Make sure you are using a base workflow that is compatible with text
         (such as Language Understanding).
 
@@ -153,20 +167,26 @@ class Clarifai(VectorStore):
             List[str]: List of IDs of the added texts.
         """
 
-        assert len(list(texts)) > 0, "No texts provided to add to the vectorstore."
+        ltexts = list(texts)
+        length = len(ltexts)
+        assert length > 0, "No texts provided to add to the vectorstore."
 
         if metadatas is not None:
-            assert len(list(texts)) == len(
+            assert length == len(
                 metadatas
            ), "Number of texts and metadatas should be the same."
 
+        batch_size = 32
         input_ids = []
-        for idx, text in enumerate(texts):
+        for idx in range(0, length, batch_size):
             try:
-                metadata = metadatas[idx] if metadatas else {}
-                input_id = self._post_text_input(text, metadata)
-                input_ids.append(input_id)
-                logger.debug(f"Input {input_id} posted successfully.")
+                batch_texts = ltexts[idx : idx + batch_size]
+                batch_metadatas = (
+                    metadatas[idx : idx + batch_size] if metadatas else None
+                )
+                result_ids = self._post_texts_as_inputs(batch_texts, batch_metadatas)
+                input_ids.extend(result_ids)
+                logger.debug(f"Input {result_ids} posted successfully.")
             except Exception as error:
                 logger.warning(f"Post inputs failed: {error}")
                 traceback.print_exc()
@@ -196,6 +216,7 @@ class Clarifai(VectorStore):
             from clarifai_grpc.grpc.api import resources_pb2, service_pb2
             from clarifai_grpc.grpc.api.status import status_code_pb2
             from google.protobuf import json_format  # type: ignore
+            from google.protobuf.struct_pb2 import Struct  # type: ignore
         except ImportError as e:
             raise ImportError(
                 "Could not import clarifai python package. "
@@ -206,28 +227,35 @@ class Clarifai(VectorStore):
         if self._number_of_docs is not None:
             k = self._number_of_docs
 
-        post_annotations_searches_response = self._stub.PostAnnotationsSearches(
-            service_pb2.PostAnnotationsSearchesRequest(
-                user_app_id=self._userDataObject,
-                searches=[
-                    resources_pb2.Search(
-                        query=resources_pb2.Query(
-                            ranks=[
-                                resources_pb2.Rank(
-                                    annotation=resources_pb2.Annotation(
-                                        data=resources_pb2.Data(
-                                            text=resources_pb2.Text(raw=query),
-                                        )
-                                    )
-                                )
-                            ]
-                        )
-                    )
-                ],
-                pagination=service_pb2.Pagination(page=1, per_page=k),
-            )
+        req = service_pb2.PostAnnotationsSearchesRequest(
+            user_app_id=self._userDataObject,
+            searches=[
+                resources_pb2.Search(
+                    query=resources_pb2.Query(
+                        ranks=[
+                            resources_pb2.Rank(
+                                annotation=resources_pb2.Annotation(
+                                    data=resources_pb2.Data(
+                                        text=resources_pb2.Text(raw=query),
+                                    )
+                                )
+                            )
+                        ]
+                    )
+                )
+            ],
+            pagination=service_pb2.Pagination(page=1, per_page=k),
         )
 
+        # Add filter by metadata if provided.
+        if filter is not None:
+            search_metadata = Struct()
+            search_metadata.update(filter)
+            f = req.searches[0].query.filters.add()
+            f.annotation.data.metadata.update(search_metadata)
+
+        post_annotations_searches_response = self._stub.PostAnnotationsSearches(req)
+
         # Check if search was successful
         if post_annotations_searches_response.status.code != status_code_pb2.SUCCESS:
             raise Exception(
@@ -238,11 +266,12 @@ class Clarifai(VectorStore):
         # Retrieve hits
         hits = post_annotations_searches_response.hits
 
-        docs_and_scores = []
-        # Iterate over hits and retrieve metadata and text
-        for hit in hits:
+        executor = ThreadPoolExecutor(max_workers=10)
+
+        def hit_to_document(hit: resources_pb2.Hit) -> Tuple[Document, float]:
             metadata = json_format.MessageToDict(hit.input.data.metadata)
-            request = requests.get(hit.input.data.text.url)
+            h = {"Authorization": f"Key {self._auth.pat}"}
+            request = requests.get(hit.input.data.text.url, headers=h)
 
             # override encoding by real educated guess as provided by chardet
             request.encoding = request.apparent_encoding
@@ -252,10 +281,11 @@ class Clarifai(VectorStore):
                 f"\tScore {hit.score:.2f} for annotation: {hit.annotation.id}\
 off input: {hit.input.id}, text: {requested_text[:125]}"
             )
+            return (Document(page_content=requested_text, metadata=metadata), hit.score)
 
-            docs_and_scores.append(
-                (Document(page_content=requested_text, metadata=metadata), hit.score)
-            )
+        # Iterate over hits and retrieve metadata and text
+        futures = [executor.submit(hit_to_document, hit) for hit in hits]
+        docs_and_scores = [future.result() for future in futures]
 
         return docs_and_scores