mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 19:09:57 +00:00
Updated integration with Clarifai python SDK functions (#13671)
Description: Updated the functions to use the new Clarifai Python SDK. Enabled initialisation of the Clarifai class with a model URL. Updated the docs with examples of the new functions.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
@@ -20,15 +20,15 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import ClarifaiEmbeddings
|
||||
clarifai = ClarifaiEmbeddings(
|
||||
model="embed-english-light-v3.0", clarifai_api_key="my-api-key"
|
||||
)
|
||||
clarifai = ClarifaiEmbeddings(user_id=USER_ID,
|
||||
app_id=APP_ID,
|
||||
model_id=MODEL_ID)
|
||||
(or)
|
||||
clarifai = ClarifaiEmbeddings(model_url=EXAMPLE_URL)
|
||||
"""
|
||||
|
||||
stub: Any #: :meta private:
|
||||
"""Clarifai stub."""
|
||||
userDataObject: Any
|
||||
"""Clarifai user data object."""
|
||||
model_url: Optional[str] = None
|
||||
"""Model url to use."""
|
||||
model_id: Optional[str] = None
|
||||
"""Model id to use."""
|
||||
model_version_id: Optional[str] = None
|
||||
@@ -48,37 +48,24 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
"""Validate that we have all required info to access Clarifai
|
||||
platform and python package exists in environment."""
|
||||
|
||||
values["pat"] = get_from_dict_or_env(values, "pat", "CLARIFAI_PAT")
|
||||
user_id = values.get("user_id")
|
||||
app_id = values.get("app_id")
|
||||
model_id = values.get("model_id")
|
||||
model_url = values.get("model_url")
|
||||
|
||||
if values["pat"] is None:
|
||||
raise ValueError("Please provide a pat.")
|
||||
if user_id is None:
|
||||
raise ValueError("Please provide a user_id.")
|
||||
if app_id is None:
|
||||
raise ValueError("Please provide a app_id.")
|
||||
if model_id is None:
|
||||
raise ValueError("Please provide a model_id.")
|
||||
if model_url is not None and model_id is not None:
|
||||
raise ValueError("Please provide either model_url or model_id, not both.")
|
||||
|
||||
try:
|
||||
from clarifai.client import create_stub
|
||||
from clarifai.client.auth.helper import ClarifaiAuthHelper
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
)
|
||||
auth = ClarifaiAuthHelper(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
pat=values["pat"],
|
||||
base=values["api_base"],
|
||||
)
|
||||
values["userDataObject"] = auth.get_user_app_id_proto()
|
||||
values["stub"] = create_stub(auth)
|
||||
if model_url is None and model_id is None:
|
||||
raise ValueError("Please provide one of model_url or model_id.")
|
||||
|
||||
if model_url is None and model_id is not None:
|
||||
if user_id is None or app_id is None:
|
||||
raise ValueError("Please provide a user_id and app_id.")
|
||||
|
||||
return values
|
||||
|
||||
@@ -91,57 +78,48 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
|
||||
try:
|
||||
from clarifai_grpc.grpc.api import (
|
||||
resources_pb2,
|
||||
service_pb2,
|
||||
)
|
||||
from clarifai_grpc.grpc.api.status import status_code_pb2
|
||||
from clarifai.client.input import Inputs
|
||||
from clarifai.client.model import Model
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
)
|
||||
if self.pat is not None:
|
||||
pat = self.pat
|
||||
if self.model_url is not None:
|
||||
_model_init = Model(url=self.model_url, pat=pat)
|
||||
else:
|
||||
_model_init = Model(
|
||||
model_id=self.model_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
pat=pat,
|
||||
)
|
||||
|
||||
input_obj = Inputs(pat=pat)
|
||||
batch_size = 32
|
||||
embeddings = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
|
||||
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
|
||||
user_app_id=self.userDataObject,
|
||||
model_id=self.model_id,
|
||||
version_id=self.model_version_id,
|
||||
inputs=[
|
||||
resources_pb2.Input(
|
||||
data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
|
||||
)
|
||||
for t in batch
|
||||
],
|
||||
)
|
||||
post_model_outputs_response = self.stub.PostModelOutputs(
|
||||
post_model_outputs_request
|
||||
)
|
||||
|
||||
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
|
||||
logger.error(post_model_outputs_response.status)
|
||||
first_output_failure = (
|
||||
post_model_outputs_response.outputs[0].status
|
||||
if len(post_model_outputs_response.outputs)
|
||||
else None
|
||||
)
|
||||
raise Exception(
|
||||
f"Post model outputs failed, status: "
|
||||
f"{post_model_outputs_response.status}, first output failure: "
|
||||
f"{first_output_failure}"
|
||||
)
|
||||
embeddings.extend(
|
||||
[
|
||||
list(o.data.embeddings[0].vector)
|
||||
for o in post_model_outputs_response.outputs
|
||||
try:
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
input_batch = [
|
||||
input_obj.get_text_input(input_id=str(id), raw_text=inp)
|
||||
for id, inp in enumerate(batch)
|
||||
]
|
||||
)
|
||||
predict_response = _model_init.predict(input_batch)
|
||||
embeddings.extend(
|
||||
[
|
||||
list(output.data.embeddings[0].vector)
|
||||
for output in predict_response.outputs
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Predict failed, exception: {e}")
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@@ -153,48 +131,34 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
|
||||
try:
|
||||
from clarifai_grpc.grpc.api import (
|
||||
resources_pb2,
|
||||
service_pb2,
|
||||
)
|
||||
from clarifai_grpc.grpc.api.status import status_code_pb2
|
||||
from clarifai.client.model import Model
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
)
|
||||
|
||||
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
|
||||
user_app_id=self.userDataObject,
|
||||
model_id=self.model_id,
|
||||
version_id=self.model_version_id,
|
||||
inputs=[
|
||||
resources_pb2.Input(
|
||||
data=resources_pb2.Data(text=resources_pb2.Text(raw=text))
|
||||
)
|
||||
],
|
||||
)
|
||||
post_model_outputs_response = self.stub.PostModelOutputs(
|
||||
post_model_outputs_request
|
||||
)
|
||||
|
||||
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
|
||||
logger.error(post_model_outputs_response.status)
|
||||
first_output_failure = (
|
||||
post_model_outputs_response.outputs[0].status
|
||||
if len(post_model_outputs_response.outputs[0])
|
||||
else None
|
||||
)
|
||||
raise Exception(
|
||||
f"Post model outputs failed, status: "
|
||||
f"{post_model_outputs_response.status}, first output failure: "
|
||||
f"{first_output_failure}"
|
||||
if self.pat is not None:
|
||||
pat = self.pat
|
||||
if self.model_url is not None:
|
||||
_model_init = Model(url=self.model_url, pat=pat)
|
||||
else:
|
||||
_model_init = Model(
|
||||
model_id=self.model_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
pat=pat,
|
||||
)
|
||||
|
||||
embeddings = [
|
||||
list(o.data.embeddings[0].vector)
|
||||
for o in post_model_outputs_response.outputs
|
||||
]
|
||||
try:
|
||||
predict_response = _model_init.predict_by_bytes(
|
||||
bytes(text, "utf-8"), input_type="text"
|
||||
)
|
||||
embeddings = [
|
||||
list(op.data.embeddings[0].vector) for op in predict_response.outputs
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Predict failed, exception: {e}")
|
||||
|
||||
return embeddings[0]
|
||||
|
@@ -12,6 +12,9 @@ from langchain.utils import get_from_dict_or_env
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
EXAMPLE_URL = "https://clarifai.com/openai/chat-completion/models/GPT-4"
|
||||
|
||||
|
||||
class Clarifai(LLM):
|
||||
"""Clarifai large language models.
|
||||
|
||||
@@ -24,27 +27,23 @@ class Clarifai(LLM):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import Clarifai
|
||||
clarifai_llm = Clarifai(pat=CLARIFAI_PAT, \
|
||||
user_id=USER_ID, app_id=APP_ID, model_id=MODEL_ID)
|
||||
clarifai_llm = Clarifai(user_id=USER_ID, app_id=APP_ID, model_id=MODEL_ID)
|
||||
(or)
|
||||
clarifai_llm = Clarifai(model_url=EXAMPLE_URL)
|
||||
"""
|
||||
|
||||
stub: Any #: :meta private:
|
||||
userDataObject: Any
|
||||
|
||||
model_url: Optional[str] = None
|
||||
"""Model url to use."""
|
||||
model_id: Optional[str] = None
|
||||
"""Model id to use."""
|
||||
|
||||
model_version_id: Optional[str] = None
|
||||
"""Model version id to use."""
|
||||
|
||||
app_id: Optional[str] = None
|
||||
"""Clarifai application id to use."""
|
||||
|
||||
user_id: Optional[str] = None
|
||||
"""Clarifai user id to use."""
|
||||
|
||||
pat: Optional[str] = None
|
||||
|
||||
"""Clarifai personal access token to use."""
|
||||
api_base: str = "https://api.clarifai.com"
|
||||
|
||||
class Config:
|
||||
@@ -60,32 +59,17 @@ class Clarifai(LLM):
|
||||
user_id = values.get("user_id")
|
||||
app_id = values.get("app_id")
|
||||
model_id = values.get("model_id")
|
||||
model_url = values.get("model_url")
|
||||
|
||||
if values["pat"] is None:
|
||||
raise ValueError("Please provide a pat.")
|
||||
if user_id is None:
|
||||
raise ValueError("Please provide a user_id.")
|
||||
if app_id is None:
|
||||
raise ValueError("Please provide a app_id.")
|
||||
if model_id is None:
|
||||
raise ValueError("Please provide a model_id.")
|
||||
if model_url is not None and model_id is not None:
|
||||
raise ValueError("Please provide either model_url or model_id, not both.")
|
||||
|
||||
try:
|
||||
from clarifai.client import create_stub
|
||||
from clarifai.client.auth.helper import ClarifaiAuthHelper
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
)
|
||||
auth = ClarifaiAuthHelper(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
pat=values["pat"],
|
||||
base=values["api_base"],
|
||||
)
|
||||
values["userDataObject"] = auth.get_user_app_id_proto()
|
||||
values["stub"] = create_stub(auth)
|
||||
if model_url is None and model_id is None:
|
||||
raise ValueError("Please provide one of model_url or model_id.")
|
||||
|
||||
if model_url is None and model_id is not None:
|
||||
if user_id is None or app_id is None:
|
||||
raise ValueError("Please provide a user_id and app_id.")
|
||||
|
||||
return values
|
||||
|
||||
@@ -99,6 +83,7 @@ class Clarifai(LLM):
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{
|
||||
"model_url": self.model_url,
|
||||
"user_id": self.user_id,
|
||||
"app_id": self.app_id,
|
||||
"model_id": self.model_id,
|
||||
@@ -115,6 +100,7 @@ class Clarifai(LLM):
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
inference_params: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to Clarfai's PostModelOutputs endpoint.
|
||||
@@ -131,54 +117,39 @@ class Clarifai(LLM):
|
||||
|
||||
response = clarifai_llm("Tell me a joke.")
|
||||
"""
|
||||
|
||||
# If version_id None, Defaults to the latest model version
|
||||
try:
|
||||
from clarifai_grpc.grpc.api import (
|
||||
resources_pb2,
|
||||
service_pb2,
|
||||
)
|
||||
from clarifai_grpc.grpc.api.status import status_code_pb2
|
||||
from clarifai.client.model import Model
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
)
|
||||
|
||||
# The userDataObject is created in the overview and
|
||||
# is required when using a PAT
|
||||
# If version_id None, Defaults to the latest model version
|
||||
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
|
||||
user_app_id=self.userDataObject,
|
||||
model_id=self.model_id,
|
||||
version_id=self.model_version_id,
|
||||
inputs=[
|
||||
resources_pb2.Input(
|
||||
data=resources_pb2.Data(text=resources_pb2.Text(raw=prompt))
|
||||
)
|
||||
],
|
||||
)
|
||||
post_model_outputs_response = self.stub.PostModelOutputs(
|
||||
post_model_outputs_request
|
||||
)
|
||||
|
||||
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
|
||||
logger.error(post_model_outputs_response.status)
|
||||
first_model_failure = (
|
||||
post_model_outputs_response.outputs[0].status
|
||||
if len(post_model_outputs_response.outputs)
|
||||
else None
|
||||
if self.pat is not None:
|
||||
pat = self.pat
|
||||
if self.model_url is not None:
|
||||
_model_init = Model(url=self.model_url, pat=pat)
|
||||
else:
|
||||
_model_init = Model(
|
||||
model_id=self.model_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
pat=pat,
|
||||
)
|
||||
raise Exception(
|
||||
f"Post model outputs failed, status: "
|
||||
f"{post_model_outputs_response.status}, first output failure: "
|
||||
f"{first_model_failure}"
|
||||
try:
|
||||
(inference_params := {}) if inference_params is None else inference_params
|
||||
predict_response = _model_init.predict_by_bytes(
|
||||
bytes(prompt, "utf-8"),
|
||||
input_type="text",
|
||||
inference_params=inference_params,
|
||||
)
|
||||
text = predict_response.outputs[0].data.text.raw
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
text = post_model_outputs_response.outputs[0].data.text.raw
|
||||
except Exception as e:
|
||||
logger.error(f"Predict failed, exception: {e}")
|
||||
|
||||
# In order to make this consistent with other endpoints, we strip them.
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
|
||||
def _generate(
|
||||
@@ -186,56 +157,50 @@ class Clarifai(LLM):
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
inference_params: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
|
||||
# TODO: add caching here.
|
||||
try:
|
||||
from clarifai_grpc.grpc.api import (
|
||||
resources_pb2,
|
||||
service_pb2,
|
||||
)
|
||||
from clarifai_grpc.grpc.api.status import status_code_pb2
|
||||
from clarifai.client.input import Inputs
|
||||
from clarifai.client.model import Model
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
)
|
||||
if self.pat is not None:
|
||||
pat = self.pat
|
||||
if self.model_url is not None:
|
||||
_model_init = Model(url=self.model_url, pat=pat)
|
||||
else:
|
||||
_model_init = Model(
|
||||
model_id=self.model_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
pat=pat,
|
||||
)
|
||||
|
||||
# TODO: add caching here.
|
||||
generations = []
|
||||
batch_size = 32
|
||||
for i in range(0, len(prompts), batch_size):
|
||||
batch = prompts[i : i + batch_size]
|
||||
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
|
||||
user_app_id=self.userDataObject,
|
||||
model_id=self.model_id,
|
||||
version_id=self.model_version_id,
|
||||
inputs=[
|
||||
resources_pb2.Input(
|
||||
data=resources_pb2.Data(text=resources_pb2.Text(raw=prompt))
|
||||
)
|
||||
for prompt in batch
|
||||
],
|
||||
)
|
||||
post_model_outputs_response = self.stub.PostModelOutputs(
|
||||
post_model_outputs_request
|
||||
)
|
||||
|
||||
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
|
||||
logger.error(post_model_outputs_response.status)
|
||||
first_model_failure = (
|
||||
post_model_outputs_response.outputs[0].status
|
||||
if len(post_model_outputs_response.outputs)
|
||||
else None
|
||||
)
|
||||
raise Exception(
|
||||
f"Post model outputs failed, status: "
|
||||
f"{post_model_outputs_response.status}, first output failure: "
|
||||
f"{first_model_failure}"
|
||||
input_obj = Inputs(pat=pat)
|
||||
try:
|
||||
for i in range(0, len(prompts), batch_size):
|
||||
batch = prompts[i : i + batch_size]
|
||||
input_batch = [
|
||||
input_obj.get_text_input(input_id=str(id), raw_text=inp)
|
||||
for id, inp in enumerate(batch)
|
||||
]
|
||||
(
|
||||
inference_params := {}
|
||||
) if inference_params is None else inference_params
|
||||
predict_response = _model_init.predict(
|
||||
inputs=input_batch, inference_params=inference_params
|
||||
)
|
||||
|
||||
for output in post_model_outputs_response.outputs:
|
||||
for output in predict_response.outputs:
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(output.data.text.raw, stop)
|
||||
else:
|
||||
@@ -243,4 +208,7 @@ class Clarifai(LLM):
|
||||
|
||||
generations.append([Generation(text=text)])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Predict failed, exception: {e}")
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from google.protobuf.struct_pb2 import Struct
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
@@ -17,7 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
class Clarifai(VectorStore):
|
||||
"""`Clarifai AI` vector store.
|
||||
|
||||
To use, you should have the ``clarifai`` python package installed.
|
||||
To use, you should have the ``clarifai`` python SDK package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -33,9 +35,8 @@ class Clarifai(VectorStore):
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
pat: Optional[str] = None,
|
||||
number_of_docs: Optional[int] = None,
|
||||
api_base: Optional[str] = None,
|
||||
pat: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize with Clarifai client.
|
||||
|
||||
@@ -50,21 +51,11 @@ class Clarifai(VectorStore):
|
||||
Raises:
|
||||
ValueError: If user ID, app ID or personal access token is not provided.
|
||||
"""
|
||||
try:
|
||||
from clarifai.auth.helper import DEFAULT_BASE, ClarifaiAuthHelper
|
||||
from clarifai.client import create_stub
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
self._api_base = DEFAULT_BASE
|
||||
|
||||
self._user_id = user_id or os.environ.get("CLARIFAI_USER_ID")
|
||||
self._app_id = app_id or os.environ.get("CLARIFAI_APP_ID")
|
||||
self._pat = pat or os.environ.get("CLARIFAI_PAT")
|
||||
if pat:
|
||||
os.environ["CLARIFAI_PAT"] = pat
|
||||
self._pat = os.environ.get("CLARIFAI_PAT")
|
||||
if self._user_id is None or self._app_id is None or self._pat is None:
|
||||
raise ValueError(
|
||||
"Could not find CLARIFAI_USER_ID, CLARIFAI_APP_ID or\
|
||||
@@ -73,77 +64,8 @@ class Clarifai(VectorStore):
|
||||
app ID and personal access token \
|
||||
from https://clarifai.com/settings/security."
|
||||
)
|
||||
|
||||
self._auth = ClarifaiAuthHelper(
|
||||
user_id=self._user_id,
|
||||
app_id=self._app_id,
|
||||
pat=self._pat,
|
||||
base=self._api_base,
|
||||
)
|
||||
self._stub = create_stub(self._auth)
|
||||
self._userDataObject = self._auth.get_user_app_id_proto()
|
||||
self._number_of_docs = number_of_docs
|
||||
|
||||
def _post_texts_as_inputs(
|
||||
self, texts: List[str], metadatas: Optional[List[dict]] = None
|
||||
) -> List[str]:
|
||||
"""Post text to Clarifai and return the ID of the input.
|
||||
|
||||
Args:
|
||||
text (str): Text to post.
|
||||
metadata (dict): Metadata to post.
|
||||
|
||||
Returns:
|
||||
str: ID of the input.
|
||||
"""
|
||||
try:
|
||||
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
||||
from clarifai_grpc.grpc.api.status import status_code_pb2
|
||||
from google.protobuf.struct_pb2 import Struct # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
) from e
|
||||
|
||||
if metadatas is not None:
|
||||
assert len(list(texts)) == len(
|
||||
metadatas
|
||||
), "Number of texts and metadatas should be the same."
|
||||
|
||||
inputs = []
|
||||
for idx, text in enumerate(texts):
|
||||
if metadatas is not None:
|
||||
input_metadata = Struct()
|
||||
input_metadata.update(metadatas[idx])
|
||||
inputs.append(
|
||||
resources_pb2.Input(
|
||||
data=resources_pb2.Data(
|
||||
text=resources_pb2.Text(raw=text),
|
||||
metadata=input_metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
post_inputs_response = self._stub.PostInputs(
|
||||
service_pb2.PostInputsRequest(
|
||||
user_app_id=self._userDataObject,
|
||||
inputs=inputs,
|
||||
)
|
||||
)
|
||||
|
||||
if post_inputs_response.status.code != status_code_pb2.SUCCESS:
|
||||
logger.error(post_inputs_response.status)
|
||||
raise Exception(
|
||||
"Post inputs failed, status: " + post_inputs_response.status.description
|
||||
)
|
||||
|
||||
input_ids = []
|
||||
for input in post_inputs_response.inputs:
|
||||
input_ids.append(input.id)
|
||||
|
||||
return input_ids
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@@ -162,9 +84,14 @@ class Clarifai(VectorStore):
|
||||
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
|
||||
ids (Optional[List[str]], optional): Optional list of IDs.
|
||||
|
||||
Returns:
|
||||
List[str]: List of IDs of the added texts.
|
||||
"""
|
||||
try:
|
||||
from clarifai.client.input import Inputs
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
"Please install it with `pip install clarifai`."
|
||||
) from e
|
||||
|
||||
ltexts = list(texts)
|
||||
length = len(ltexts)
|
||||
@@ -175,29 +102,51 @@ class Clarifai(VectorStore):
|
||||
metadatas
|
||||
), "Number of texts and metadatas should be the same."
|
||||
|
||||
if ids is not None:
|
||||
assert len(ltexts) == len(
|
||||
ids
|
||||
), "Number of text inputs and input ids should be the same."
|
||||
|
||||
input_obj = Inputs(app_id=self._app_id, user_id=self._user_id)
|
||||
batch_size = 32
|
||||
input_ids = []
|
||||
input_job_ids = []
|
||||
for idx in range(0, length, batch_size):
|
||||
try:
|
||||
batch_texts = ltexts[idx : idx + batch_size]
|
||||
batch_metadatas = (
|
||||
metadatas[idx : idx + batch_size] if metadatas else None
|
||||
)
|
||||
result_ids = self._post_texts_as_inputs(batch_texts, batch_metadatas)
|
||||
input_ids.extend(result_ids)
|
||||
logger.debug(f"Input {result_ids} posted successfully.")
|
||||
if batch_metadatas is not None:
|
||||
meta_list = []
|
||||
for meta in batch_metadatas:
|
||||
meta_struct = Struct()
|
||||
meta_struct.update(meta)
|
||||
meta_list.append(meta_struct)
|
||||
if ids is None:
|
||||
ids = [uuid.uuid4().hex for _ in range(len(batch_texts))]
|
||||
input_batch = [
|
||||
input_obj.get_text_input(
|
||||
input_id=ids[id],
|
||||
raw_text=inp,
|
||||
metadata=meta_list[id] if batch_metadatas else None,
|
||||
)
|
||||
for id, inp in enumerate(batch_texts)
|
||||
]
|
||||
result_id = input_obj.upload_inputs(inputs=input_batch)
|
||||
input_job_ids.extend(result_id)
|
||||
logger.debug("Input posted successfully.")
|
||||
|
||||
except Exception as error:
|
||||
logger.warning(f"Post inputs failed: {error}")
|
||||
traceback.print_exc()
|
||||
|
||||
return input_ids
|
||||
return input_job_ids
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
namespace: Optional[str] = None,
|
||||
filters: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Run similarity search with score using Clarifai.
|
||||
@@ -212,10 +161,9 @@ class Clarifai(VectorStore):
|
||||
List[Document]: List of documents most similar to the query text.
|
||||
"""
|
||||
try:
|
||||
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
||||
from clarifai_grpc.grpc.api.status import status_code_pb2
|
||||
from clarifai.client.search import Search
|
||||
from clarifai_grpc.grpc.api import resources_pb2
|
||||
from google.protobuf import json_format # type: ignore
|
||||
from google.protobuf.struct_pb2 import Struct # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import clarifai python package. "
|
||||
@@ -226,50 +174,22 @@ class Clarifai(VectorStore):
|
||||
if self._number_of_docs is not None:
|
||||
k = self._number_of_docs
|
||||
|
||||
req = service_pb2.PostAnnotationsSearchesRequest(
|
||||
user_app_id=self._userDataObject,
|
||||
searches=[
|
||||
resources_pb2.Search(
|
||||
query=resources_pb2.Query(
|
||||
ranks=[
|
||||
resources_pb2.Rank(
|
||||
annotation=resources_pb2.Annotation(
|
||||
data=resources_pb2.Data(
|
||||
text=resources_pb2.Text(raw=query),
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
],
|
||||
pagination=service_pb2.Pagination(page=1, per_page=k),
|
||||
)
|
||||
|
||||
search_obj = Search(user_id=self._user_id, app_id=self._app_id, top_k=k)
|
||||
rank = [{"text_raw": query}]
|
||||
# Add filter by metadata if provided.
|
||||
if filter is not None:
|
||||
search_metadata = Struct()
|
||||
search_metadata.update(filter)
|
||||
f = req.searches[0].query.filters.add()
|
||||
f.annotation.data.metadata.update(search_metadata)
|
||||
|
||||
post_annotations_searches_response = self._stub.PostAnnotationsSearches(req)
|
||||
|
||||
# Check if search was successful
|
||||
if post_annotations_searches_response.status.code != status_code_pb2.SUCCESS:
|
||||
raise Exception(
|
||||
"Post searches failed, status: "
|
||||
+ post_annotations_searches_response.status.description
|
||||
)
|
||||
if filters is not None:
|
||||
search_metadata = {"metadata": filters}
|
||||
search_response = search_obj.query(ranks=rank, filters=[search_metadata])
|
||||
else:
|
||||
search_response = search_obj.query(ranks=rank)
|
||||
|
||||
# Retrieve hits
|
||||
hits = post_annotations_searches_response.hits
|
||||
|
||||
hits = [hit for data in search_response for hit in data.hits]
|
||||
executor = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
def hit_to_document(hit: resources_pb2.Hit) -> Tuple[Document, float]:
|
||||
metadata = json_format.MessageToDict(hit.input.data.metadata)
|
||||
h = {"Authorization": f"Key {self._auth.pat}"}
|
||||
h = {"Authorization": f"Key {self._pat}"}
|
||||
request = requests.get(hit.input.data.text.url, headers=h)
|
||||
|
||||
# override encoding by real educated guess as provided by chardet
|
||||
@@ -314,9 +234,8 @@ class Clarifai(VectorStore):
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
pat: Optional[str] = None,
|
||||
number_of_docs: Optional[int] = None,
|
||||
api_base: Optional[str] = None,
|
||||
pat: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Clarifai:
|
||||
"""Create a Clarifai vectorstore from a list of texts.
|
||||
@@ -325,10 +244,8 @@ class Clarifai(VectorStore):
|
||||
user_id (str): User ID.
|
||||
app_id (str): App ID.
|
||||
texts (List[str]): List of texts to add.
|
||||
pat (Optional[str]): Personal access token. Defaults to None.
|
||||
number_of_docs (Optional[int]): Number of documents to return
|
||||
during vector search. Defaults to None.
|
||||
api_base (Optional[str]): API base. Defaults to None.
|
||||
metadatas (Optional[List[dict]]): Optional list of metadatas.
|
||||
Defaults to None.
|
||||
|
||||
@@ -338,9 +255,8 @@ class Clarifai(VectorStore):
|
||||
clarifai_vector_db = cls(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
pat=pat,
|
||||
number_of_docs=number_of_docs,
|
||||
api_base=api_base,
|
||||
pat=pat,
|
||||
)
|
||||
clarifai_vector_db.add_texts(texts=texts, metadatas=metadatas)
|
||||
return clarifai_vector_db
|
||||
@@ -352,9 +268,8 @@ class Clarifai(VectorStore):
|
||||
embedding: Optional[Embeddings] = None,
|
||||
user_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
pat: Optional[str] = None,
|
||||
number_of_docs: Optional[int] = None,
|
||||
api_base: Optional[str] = None,
|
||||
pat: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Clarifai:
|
||||
"""Create a Clarifai vectorstore from a list of documents.
|
||||
@@ -363,10 +278,8 @@ class Clarifai(VectorStore):
|
||||
user_id (str): User ID.
|
||||
app_id (str): App ID.
|
||||
documents (List[Document]): List of documents to add.
|
||||
pat (Optional[str]): Personal access token. Defaults to None.
|
||||
number_of_docs (Optional[int]): Number of documents to return
|
||||
during vector search. Defaults to None.
|
||||
api_base (Optional[str]): API base. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Clarifai: Clarifai vectorstore.
|
||||
@@ -377,8 +290,7 @@ class Clarifai(VectorStore):
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
texts=texts,
|
||||
pat=pat,
|
||||
number_of_docs=number_of_docs,
|
||||
api_base=api_base,
|
||||
pat=pat,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
Reference in New Issue
Block a user