mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 07:50:39 +00:00
add embed_image
API to JinaEmbedding (#22416)
- **Description:** Add `embed_image` to JinaEmbedding to embed images - **Twitter handle:** https://x.com/JinaAI_
This commit is contained in:
parent
3e92ed8056
commit
a7ae16f912
@ -1,4 +1,7 @@
|
||||
import base64
|
||||
from os.path import exists
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
@ -8,6 +11,18 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
|
||||
|
||||
|
||||
def is_local(url: str) -> bool:
|
||||
url_parsed = urlparse(url)
|
||||
if url_parsed.scheme in ("file", ""): # Possibly a local file
|
||||
return exists(url_parsed.path)
|
||||
return False
|
||||
|
||||
|
||||
def get_bytes_str(file_path: str) -> str:
|
||||
with open(file_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
|
||||
class JinaEmbeddings(BaseModel, Embeddings):
|
||||
"""Jina embedding models."""
|
||||
|
||||
@ -40,10 +55,10 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
values["session"] = session
|
||||
return values
|
||||
|
||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed(self, input: Any) -> List[List[float]]:
|
||||
# Call Jina AI Embedding API
|
||||
resp = self.session.post( # type: ignore
|
||||
JINA_API_URL, json={"input": texts, "model": self.model_name}
|
||||
JINA_API_URL, json={"input": input, "model": self.model_name}
|
||||
).json()
|
||||
if "data" not in resp:
|
||||
raise RuntimeError(resp["detail"])
|
||||
@ -73,3 +88,18 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self._embed([text])[0]
|
||||
|
||||
def embed_images(self, uris: List[str]) -> List[List[float]]:
|
||||
"""Call out to Jina's image embedding endpoint.
|
||||
Args:
|
||||
uris: The list of uris to embed.
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
input = []
|
||||
for uri in uris:
|
||||
if is_local(uri):
|
||||
input.append({"bytes": get_bytes_str(uri)})
|
||||
else:
|
||||
input.append({"url": uri})
|
||||
return self._embed(input)
|
||||
|
Loading…
Reference in New Issue
Block a user