Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-27 00:48:45 +00:00)
feat(llms): support ERNIE Embedding-V1 (#9370)
- Description: support [ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu), which is part of the ERNIE ecosystem
- Issue: None
- Dependencies: None
- Tag maintainer: @baskaryan

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent f116e10d53
commit 05aa02005b
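For quick orientation, a minimal usage sketch mirroring the notebook added in this commit. The credential values are placeholders; the new module reads them from the ERNIE_CLIENT_ID / ERNIE_CLIENT_SECRET environment variables when they are not passed explicitly.

```python
import os

from langchain.embeddings import ErnieEmbeddings

# Placeholder credentials obtained from Baidu's Wenxin Workshop console;
# ErnieEmbeddings picks these up via its root validator.
os.environ["ERNIE_CLIENT_ID"] = "your-client-id"
os.environ["ERNIE_CLIENT_SECRET"] = "your-client-secret"

embeddings = ErnieEmbeddings()

# Embed a single query string -> one vector.
query_result = embeddings.embed_query("foo")

# Embed a batch of documents -> one vector per document.
doc_results = embeddings.embed_documents(["foo"])
```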
docs/extras/integrations/text_embedding/ernie.ipynb (new file, 60 lines)
@@ -0,0 +1,60 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ERNIE Embedding-V1\n",
    "\n",
    "[ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu) is a text representation model based on Baidu Wenxin's large-scale model technology, \n",
    "which converts text into a vector form represented by numerical values, and is used in text retrieval, information recommendation, knowledge mining and other scenarios."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings import ErnieEmbeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = ErnieEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_result = embeddings.embed_query(\"foo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_results = embeddings.embed_documents([\"foo\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
libs/langchain/langchain/embeddings/__init__.py

@@ -28,6 +28,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings
 from langchain.embeddings.edenai import EdenAiEmbeddings
 from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
 from langchain.embeddings.embaas import EmbaasEmbeddings
+from langchain.embeddings.ernie import ErnieEmbeddings
 from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
 from langchain.embeddings.google_palm import GooglePalmEmbeddings
 from langchain.embeddings.gpt4all import GPT4AllEmbeddings

@@ -101,6 +102,7 @@ __all__ = [
     "LocalAIEmbeddings",
     "AwaEmbeddings",
     "HuggingFaceBgeEmbeddings",
+    "ErnieEmbeddings",
 ]
libs/langchain/langchain/embeddings/ernie.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import logging
import threading
from typing import Dict, List, Optional

import requests

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class ErnieEmbeddings(BaseModel, Embeddings):
    """`Ernie Embeddings V1` embedding models."""

    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    access_token: Optional[str] = None

    chunk_size: int = 16

    model_name = "ErnieBot-Embedding-V1"

    _lock = threading.Lock()

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _embedding(self, json: object) -> dict:
        base_url = (
            "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
        )
        resp = requests.post(
            f"{base_url}/embedding-v1",
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=json,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        with self._lock:
            logger.debug("Refreshing access token")
            base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
            resp = requests.post(
                base_url,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            self.access_token = str(resp.json().get("access_token"))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        if not self.access_token:
            self._refresh_access_token_with_lock()
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            resp = self._embedding({"input": [text for text in chunk]})
            if resp.get("error_code"):
                if resp.get("error_code") == 111:
                    self._refresh_access_token_with_lock()
                    resp = self._embedding({"input": [text for text in chunk]})
                else:
                    raise ValueError(f"Error from Ernie: {resp}")
            lst.extend([i["embedding"] for i in resp["data"]])
        return lst

    def embed_query(self, text: str) -> List[float]:
        if not self.access_token:
            self._refresh_access_token_with_lock()
        resp = self._embedding({"input": [text]})
        if resp.get("error_code"):
            if resp.get("error_code") == 111:
                self._refresh_access_token_with_lock()
                resp = self._embedding({"input": [text]})
            else:
                raise ValueError(f"Error from Ernie: {resp}")
        return resp["data"][0]["embedding"]
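A brief note on the module above: credentials can also be passed directly to the constructor instead of via environment variables, an access token is fetched lazily and refreshed under a lock when the API reports error_code 111, and embed_documents sends texts in batches of chunk_size (default 16) per request. A hedged sketch with placeholder credentials:

```python
from langchain.embeddings import ErnieEmbeddings

# Placeholder credentials passed explicitly instead of via environment variables.
embeddings = ErnieEmbeddings(
    ernie_client_id="your-client-id",
    ernie_client_secret="your-client-secret",
    chunk_size=16,  # number of texts sent per request to the embedding endpoint
)

# 20 texts with chunk_size=16 results in two API calls (16 + 4);
# the returned vectors are concatenated in the original input order.
vectors = embeddings.embed_documents([f"text-{i}" for i in range(20)])
```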
New test file (41 lines)

@@ -0,0 +1,41 @@
import pytest

from langchain.embeddings.ernie import ErnieEmbeddings


def test_embedding_documents_1() -> None:
    documents = ["foo bar"]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 384


def test_embedding_documents_2() -> None:
    documents = ["foo", "bar"]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) == 384
    assert len(output[1]) == 384


def test_embedding_query() -> None:
    query = "foo"
    embedding = ErnieEmbeddings()
    output = embedding.embed_query(query)
    assert len(output) == 384


def test_max_chunks() -> None:
    documents = [f"text-{i}" for i in range(20)]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 20


def test_too_many_chunks() -> None:
    documents = [f"text-{i}" for i in range(20)]
    embedding = ErnieEmbeddings(chunk_size=20)
    with pytest.raises(ValueError):
        embedding.embed_documents(documents)