mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 15:59:56 +00:00
add embed_image
API to JinaEmbedding (#22416)
- **Description:** Add `embed_image` to JinaEmbedding to embed images - **Twitter handle:** https://x.com/JinaAI_
This commit is contained in:
parent
3e92ed8056
commit
a7ae16f912
@ -1,4 +1,7 @@
|
|||||||
|
import base64
|
||||||
|
from os.path import exists
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
@ -8,6 +11,18 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
|||||||
JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
|
JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
|
||||||
|
|
||||||
|
|
||||||
|
def is_local(url: str) -> bool:
|
||||||
|
url_parsed = urlparse(url)
|
||||||
|
if url_parsed.scheme in ("file", ""): # Possibly a local file
|
||||||
|
return exists(url_parsed.path)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_bytes_str(file_path: str) -> str:
|
||||||
|
with open(file_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
class JinaEmbeddings(BaseModel, Embeddings):
|
class JinaEmbeddings(BaseModel, Embeddings):
|
||||||
"""Jina embedding models."""
|
"""Jina embedding models."""
|
||||||
|
|
||||||
@ -40,10 +55,10 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
|||||||
values["session"] = session
|
values["session"] = session
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
def _embed(self, input: Any) -> List[List[float]]:
|
||||||
# Call Jina AI Embedding API
|
# Call Jina AI Embedding API
|
||||||
resp = self.session.post( # type: ignore
|
resp = self.session.post( # type: ignore
|
||||||
JINA_API_URL, json={"input": texts, "model": self.model_name}
|
JINA_API_URL, json={"input": input, "model": self.model_name}
|
||||||
).json()
|
).json()
|
||||||
if "data" not in resp:
|
if "data" not in resp:
|
||||||
raise RuntimeError(resp["detail"])
|
raise RuntimeError(resp["detail"])
|
||||||
@ -73,3 +88,18 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
|||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
return self._embed([text])[0]
|
return self._embed([text])[0]
|
||||||
|
|
||||||
|
def embed_images(self, uris: List[str]) -> List[List[float]]:
|
||||||
|
"""Call out to Jina's image embedding endpoint.
|
||||||
|
Args:
|
||||||
|
uris: The list of uris to embed.
|
||||||
|
Returns:
|
||||||
|
List of embeddings, one for each text.
|
||||||
|
"""
|
||||||
|
input = []
|
||||||
|
for uri in uris:
|
||||||
|
if is_local(uri):
|
||||||
|
input.append({"bytes": get_bytes_str(uri)})
|
||||||
|
else:
|
||||||
|
input.append({"url": uri})
|
||||||
|
return self._embed(input)
|
||||||
|
Loading…
Reference in New Issue
Block a user