add embed_image API to JinaEmbedding (#22416)

- **Description:** Add `embed_image` to JinaEmbedding to embed images
 - **Twitter handle:** https://x.com/JinaAI_
This commit is contained in:
Joan Fontanals 2024-06-03 16:23:37 +02:00 committed by GitHub
parent 3e92ed8056
commit a7ae16f912
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,7 @@
import base64
from os.path import exists
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import requests
from langchain_core.embeddings import Embeddings
@ -8,6 +11,18 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
def is_local(url: str) -> bool:
url_parsed = urlparse(url)
if url_parsed.scheme in ("file", ""): # Possibly a local file
return exists(url_parsed.path)
return False
def get_bytes_str(file_path: str) -> str:
with open(file_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
class JinaEmbeddings(BaseModel, Embeddings):
"""Jina embedding models."""
@ -40,10 +55,10 @@ class JinaEmbeddings(BaseModel, Embeddings):
values["session"] = session
return values
def _embed(self, texts: List[str]) -> List[List[float]]:
def _embed(self, input: Any) -> List[List[float]]:
# Call Jina AI Embedding API
resp = self.session.post( # type: ignore
JINA_API_URL, json={"input": texts, "model": self.model_name}
JINA_API_URL, json={"input": input, "model": self.model_name}
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
@ -73,3 +88,18 @@ class JinaEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
return self._embed([text])[0]
def embed_images(self, uris: List[str]) -> List[List[float]]:
"""Call out to Jina's image embedding endpoint.
Args:
uris: The list of uris to embed.
Returns:
List of embeddings, one for each text.
"""
input = []
for uri in uris:
if is_local(uri):
input.append({"bytes": get_bytes_str(uri)})
else:
input.append({"url": uri})
return self._embed(input)