mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
adapt Jina Embeddings to new Jina AI Embedding API (#13658)
- **Description:** Adapt JinaEmbeddings to run with the new Jina AI Embedding platform - **Twitter handle:** https://twitter.com/JinaAI_ --------- Co-authored-by: Joan Fontanals Martinez <joan.fontanals.martinez@jina.ai> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
@@ -7,69 +6,54 @@ from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
|
||||
|
||||
|
||||
class JinaEmbeddings(BaseModel, Embeddings):
|
||||
"""Jina embedding models."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
|
||||
model_name: str = "ViT-B-32::openai"
|
||||
"""Model name to use."""
|
||||
|
||||
jina_auth_token: Optional[str] = None
|
||||
jina_api_url: str = "https://api.clip.jina.ai/api/v1/models/"
|
||||
request_headers: Optional[dict] = None
|
||||
session: Any #: :meta private:
|
||||
model_name: str = "jina-embeddings-v2-base-en"
|
||||
jina_api_key: Optional[str] = None
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that auth token exists in environment."""
|
||||
# Set Auth
|
||||
jina_auth_token = get_from_dict_or_env(
|
||||
values, "jina_auth_token", "JINA_AUTH_TOKEN"
|
||||
try:
|
||||
jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
|
||||
except ValueError as original_exc:
|
||||
try:
|
||||
jina_api_key = get_from_dict_or_env(
|
||||
values, "jina_auth_token", "JINA_AUTH_TOKEN"
|
||||
)
|
||||
except ValueError:
|
||||
raise original_exc
|
||||
session = requests.Session()
|
||||
session.headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {jina_api_key}",
|
||||
"Accept-Encoding": "identity",
|
||||
"Content-type": "application/json",
|
||||
}
|
||||
)
|
||||
values["jina_auth_token"] = jina_auth_token
|
||||
values["request_headers"] = (("authorization", jina_auth_token),)
|
||||
|
||||
# Test that package is installed
|
||||
try:
|
||||
import jina
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import `jina` python package. "
|
||||
"Please install it with `pip install jina`."
|
||||
)
|
||||
|
||||
# Setup client
|
||||
jina_api_url = os.environ.get("JINA_API_URL", values["jina_api_url"])
|
||||
model_name = values["model_name"]
|
||||
try:
|
||||
resp = requests.get(
|
||||
jina_api_url + f"?model_name={model_name}",
|
||||
headers={"Authorization": jina_auth_token},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
raise ValueError(
|
||||
"The given Jina auth token is invalid. "
|
||||
"Please check your Jina auth token."
|
||||
)
|
||||
elif resp.status_code == 404:
|
||||
raise ValueError(
|
||||
f"The given model name `{model_name}` is not valid. "
|
||||
f"Please go to https://cloud.jina.ai/user/inference "
|
||||
f"and create a model with the given model name."
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
endpoint = resp.json()["endpoints"]["grpc"]
|
||||
values["client"] = jina.Client(host=endpoint)
|
||||
except requests.exceptions.HTTPError as err:
|
||||
raise ValueError(f"Error: {err!r}")
|
||||
values["session"] = session
|
||||
return values
|
||||
|
||||
def _post(self, docs: List[Any], **kwargs: Any) -> Any:
|
||||
payload = dict(inputs=docs, metadata=self.request_headers, **kwargs)
|
||||
return self.client.post(on="/encode", **payload)
|
||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||
# Call Jina AI Embedding API
|
||||
resp = self.session.post( # type: ignore
|
||||
JINA_API_URL, json={"input": texts, "model": self.model_name}
|
||||
).json()
|
||||
if "data" not in resp:
|
||||
raise RuntimeError(resp["detail"])
|
||||
|
||||
embeddings = resp["data"]
|
||||
|
||||
# Sort resulting embeddings by index
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
|
||||
|
||||
# Return just the embeddings
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
@@ -78,12 +62,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
from docarray import Document, DocumentArray
|
||||
|
||||
embeddings = self._post(
|
||||
docs=DocumentArray([Document(text=t) for t in texts])
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
return self._embed(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
@@ -92,7 +71,4 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
from docarray import Document, DocumentArray
|
||||
|
||||
embedding = self._post(docs=DocumentArray([Document(text=text)])).embeddings[0]
|
||||
return list(map(float, embedding))
|
||||
return self._embed([text])[0]
|
||||
|
Reference in New Issue
Block a user