diff --git a/libs/community/langchain_community/embeddings/jina.py b/libs/community/langchain_community/embeddings/jina.py
index dc66409e60c..7c50faf46b1 100644
--- a/libs/community/langchain_community/embeddings/jina.py
+++ b/libs/community/langchain_community/embeddings/jina.py
@@ -1,4 +1,7 @@
+import base64
+from os.path import exists
 from typing import Any, Dict, List, Optional
+from urllib.parse import urlparse
 
 import requests
 from langchain_core.embeddings import Embeddings
@@ -8,6 +11,36 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 
 JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
 
 
+def is_local(url: str) -> bool:
+    """Check whether ``url`` points at an existing local file.
+    Args:
+        url: A filesystem path, a ``file://`` URI, or a remote URL.
+    Returns:
+        True if ``url`` has no scheme or the ``file`` scheme and its path
+        component exists on disk, False otherwise.
+    """
+    url_parsed = urlparse(url)
+    if url_parsed.scheme in ("file", ""):  # Possibly a local file
+        return exists(url_parsed.path)
+    return False
+
+
+def get_bytes_str(file_path: str) -> str:
+    """Read a local file and return its content as a base64 string.
+    Args:
+        file_path: A filesystem path or a ``file://`` URI.
+    Returns:
+        The file's bytes, base64-encoded and decoded as UTF-8 text.
+    """
+    url_parsed = urlparse(file_path)
+    if url_parsed.scheme == "file":
+        # is_local() accepts ``file://`` URIs but open() does not, so open
+        # the same path component that is_local() checked.
+        file_path = url_parsed.path
+    with open(file_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
 class JinaEmbeddings(BaseModel, Embeddings):
     """Jina embedding models."""
@@ -40,10 +73,10 @@ class JinaEmbeddings(BaseModel, Embeddings):
         values["session"] = session
         return values
 
-    def _embed(self, texts: List[str]) -> List[List[float]]:
+    def _embed(self, input: Any) -> List[List[float]]:
         # Call Jina AI Embedding API
         resp = self.session.post(  # type: ignore
-            JINA_API_URL, json={"input": texts, "model": self.model_name}
+            JINA_API_URL, json={"input": input, "model": self.model_name}
         ).json()
         if "data" not in resp:
             raise RuntimeError(resp["detail"])
@@ -73,3 +106,19 @@ class JinaEmbeddings(BaseModel, Embeddings):
             Embeddings for the text.
         """
         return self._embed([text])[0]
+
+    def embed_images(self, uris: List[str]) -> List[List[float]]:
+        """Call out to Jina's image embedding endpoint.
+        Args:
+            uris: The list of uris to embed. Each is either a remote URL or
+                a local file (plain path or ``file://`` URI).
+        Returns:
+            List of embeddings, one for each uri.
+        """
+        # Local files are sent inline as base64; anything else is passed
+        # through unchanged as a URL.
+        payload = [
+            {"bytes": get_bytes_str(uri)} if is_local(uri) else {"url": uri}
+            for uri in uris
+        ]
+        return self._embed(payload)