feat(llms): support ERNIE Embedding-V1 (#9370)

- Description: support [ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu), which is part of the ERNIE ecosystem
- Issue: None
- Dependencies: None
- Tag maintainer: @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
axiangcoding 2023-08-21 22:52:25 +08:00 committed by GitHub
parent f116e10d53
commit 05aa02005b
4 changed files with 205 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ERNIE Embedding-V1\n",
"\n",
"[ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu) is a text representation model based on Baidu Wenxin's large-scale model technology, \n",
"which converts text into a vector form represented by numerical values, and is used in text retrieval, information recommendation, knowledge mining and other scenarios."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import ErnieEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embeddings = ErnieEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"doc_results = embeddings.embed_documents([\"foo\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
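The notebook above constructs `ErnieEmbeddings()` with no arguments, which relies on the credentials being available in the environment: the `validate_environment` root validator in the new `ernie` module resolves `ernie_client_id` and `ernie_client_secret` from the constructor kwargs or from the `ERNIE_CLIENT_ID` / `ERNIE_CLIENT_SECRET` environment variables. A minimal sketch of the two ways to supply them (the credential strings below are placeholders):

import os

from langchain.embeddings import ErnieEmbeddings

# Option 1: rely on environment variables, as the notebook does.
os.environ["ERNIE_CLIENT_ID"] = "<your-client-id>"          # placeholder
os.environ["ERNIE_CLIENT_SECRET"] = "<your-client-secret>"  # placeholder
embeddings = ErnieEmbeddings()

# Option 2: pass the credentials explicitly to the constructor.
embeddings = ErnieEmbeddings(
    ernie_client_id="<your-client-id>",          # placeholder
    ernie_client_secret="<your-client-secret>",  # placeholder
)

query_vector = embeddings.embed_query("foo")        # one embedding
doc_vectors = embeddings.embed_documents(["foo"])   # list of embeddings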

View File

@@ -28,6 +28,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings
from langchain.embeddings.edenai import EdenAiEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
from langchain.embeddings.embaas import EmbaasEmbeddings
from langchain.embeddings.ernie import ErnieEmbeddings
from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.embeddings.gpt4all import GPT4AllEmbeddings
@@ -101,6 +102,7 @@ __all__ = [
    "LocalAIEmbeddings",
    "AwaEmbeddings",
    "HuggingFaceBgeEmbeddings",
    "ErnieEmbeddings",
]

View File

@@ -0,0 +1,102 @@
import logging
import threading
from typing import Dict, List, Optional

import requests

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class ErnieEmbeddings(BaseModel, Embeddings):
    """`Ernie Embeddings V1` embedding models."""

    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    access_token: Optional[str] = None

    chunk_size: int = 16

    model_name = "ErnieBot-Embedding-V1"

    _lock = threading.Lock()

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _embedding(self, json: object) -> dict:
        base_url = (
            "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
        )
        resp = requests.post(
            f"{base_url}/embedding-v1",
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=json,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        with self._lock:
            logger.debug("Refreshing access token")
            base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
            resp = requests.post(
                base_url,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            self.access_token = str(resp.json().get("access_token"))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        if not self.access_token:
            self._refresh_access_token_with_lock()
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            resp = self._embedding({"input": [text for text in chunk]})
            if resp.get("error_code"):
                if resp.get("error_code") == 111:
                    self._refresh_access_token_with_lock()
                    resp = self._embedding({"input": [text for text in chunk]})
                else:
                    raise ValueError(f"Error from Ernie: {resp}")
            lst.extend([i["embedding"] for i in resp["data"]])
        return lst

    def embed_query(self, text: str) -> List[float]:
        if not self.access_token:
            self._refresh_access_token_with_lock()
        resp = self._embedding({"input": [text]})
        if resp.get("error_code"):
            if resp.get("error_code") == 111:
                self._refresh_access_token_with_lock()
                resp = self._embedding({"input": [text]})
            else:
                raise ValueError(f"Error from Ernie: {resp}")
        return resp["data"][0]["embedding"]
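
For reference, `embed_documents` batches its input: the texts are sliced into chunks of `chunk_size` (default 16), one request is sent per chunk, and if the service returns error code 111 the client refreshes the access token under the lock and retries that chunk once; any other error code raises `ValueError`. A standalone sketch of the same slicing arithmetic (plain Python, no API calls; `split_into_chunks` is an illustrative helper, not part of the module):

from typing import List


def split_into_chunks(texts: List[str], chunk_size: int = 16) -> List[List[str]]:
    # Same list comprehension as in embed_documents.
    return [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]


# 20 documents with the default chunk_size of 16 result in two requests:
# one carrying 16 texts and one carrying the remaining 4.
chunks = split_into_chunks([f"text-{i}" for i in range(20)])
assert [len(c) for c in chunks] == [16, 4]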

View File

@@ -0,0 +1,41 @@
import pytest

from langchain.embeddings.ernie import ErnieEmbeddings


def test_embedding_documents_1() -> None:
    documents = ["foo bar"]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 384


def test_embedding_documents_2() -> None:
    documents = ["foo", "bar"]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) == 384
    assert len(output[1]) == 384


def test_embedding_query() -> None:
    query = "foo"
    embedding = ErnieEmbeddings()
    output = embedding.embed_query(query)
    assert len(output) == 384


def test_max_chunks() -> None:
    documents = [f"text-{i}" for i in range(20)]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 20


def test_too_many_chunks() -> None:
    documents = [f"text-{i}" for i in range(20)]
    embedding = ErnieEmbeddings(chunk_size=20)
    with pytest.raises(ValueError):
        embedding.embed_documents(documents)
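
These tests are live integration tests: each `ErnieEmbeddings()` call resolves `ERNIE_CLIENT_ID` / `ERNIE_CLIENT_SECRET` and hits the Qianfan endpoint, and `test_too_many_chunks` expects a `ValueError` presumably because a single request with 20 texts exceeds what the service accepts per batch. A hedged sketch of how such tests could be skipped when credentials are absent (the marker and smoke test below are illustrative, not part of this commit):

import os

import pytest

from langchain.embeddings.ernie import ErnieEmbeddings

# Skip the live tests when the Qianfan credentials are not configured.
requires_ernie_credentials = pytest.mark.skipif(
    not (os.environ.get("ERNIE_CLIENT_ID") and os.environ.get("ERNIE_CLIENT_SECRET")),
    reason="ERNIE_CLIENT_ID / ERNIE_CLIENT_SECRET are not set",
)


@requires_ernie_credentials
def test_embedding_query_smoke() -> None:
    output = ErnieEmbeddings().embed_query("foo")
    assert len(output) == 384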