From 05aa02005bf3d06a22c5d8f5264bd826992c275b Mon Sep 17 00:00:00 2001 From: axiangcoding Date: Mon, 21 Aug 2023 22:52:25 +0800 Subject: [PATCH] feat(llms): support ERNIE Embedding-V1 (#9370) - Description: support [ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu), which is part of ERNIE ecology - Issue: None - Dependencies: None - Tag maintainer: @baskaryan --------- Co-authored-by: Bagatur --- .../integrations/text_embedding/ernie.ipynb | 60 +++++++++++ .../langchain/embeddings/__init__.py | 2 + libs/langchain/langchain/embeddings/ernie.py | 102 ++++++++++++++++++ .../embeddings/test_ernie.py | 41 +++++++ 4 files changed, 205 insertions(+) create mode 100644 docs/extras/integrations/text_embedding/ernie.ipynb create mode 100644 libs/langchain/langchain/embeddings/ernie.py create mode 100644 libs/langchain/tests/integration_tests/embeddings/test_ernie.py diff --git a/docs/extras/integrations/text_embedding/ernie.ipynb b/docs/extras/integrations/text_embedding/ernie.ipynb new file mode 100644 index 00000000000..80b563eae94 --- /dev/null +++ b/docs/extras/integrations/text_embedding/ernie.ipynb @@ -0,0 +1,60 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ERNIE Embedding-V1\n", + "\n", + "[ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu) is a text representation model based on Baidu Wenxin's large-scale model technology, \n", + "which converts text into a vector form represented by numerical values, and is used in text retrieval, information recommendation, knowledge mining and other scenarios." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import ErnieEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = ErnieEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_result = embeddings.embed_query(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "doc_results = embeddings.embed_documents([\"foo\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index 1c03d2c4042..87cb5e90d5a 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -28,6 +28,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings from langchain.embeddings.edenai import EdenAiEmbeddings from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings from langchain.embeddings.embaas import EmbaasEmbeddings +from langchain.embeddings.ernie import ErnieEmbeddings from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings from langchain.embeddings.google_palm import GooglePalmEmbeddings from langchain.embeddings.gpt4all import GPT4AllEmbeddings @@ -101,6 +102,7 @@ __all__ = [ "LocalAIEmbeddings", "AwaEmbeddings", "HuggingFaceBgeEmbeddings", + "ErnieEmbeddings", ] diff --git a/libs/langchain/langchain/embeddings/ernie.py b/libs/langchain/langchain/embeddings/ernie.py new file mode 100644 index 00000000000..b8213651adc --- /dev/null +++ b/libs/langchain/langchain/embeddings/ernie.py @@ -0,0 +1,102 @@ +import logging +import threading +from typing import Dict, List, Optional + +import requests + +from langchain.embeddings.base import Embeddings +from langchain.pydantic_v1 import BaseModel, root_validator +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +class ErnieEmbeddings(BaseModel, Embeddings): + """`Ernie Embeddings V1` embedding models.""" + + ernie_client_id: Optional[str] = None + ernie_client_secret: Optional[str] = None + access_token: Optional[str] = None + + chunk_size: int = 16 + + model_name = "ErnieBot-Embedding-V1" + + _lock = threading.Lock() + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + values["ernie_client_id"] = get_from_dict_or_env( + values, + "ernie_client_id", + "ERNIE_CLIENT_ID", + ) + values["ernie_client_secret"] = get_from_dict_or_env( + values, + "ernie_client_secret", + "ERNIE_CLIENT_SECRET", + ) + return values + + def _embedding(self, json: object) -> dict: + base_url = ( + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings" + ) + resp = requests.post( + f"{base_url}/embedding-v1", + headers={ + "Content-Type": "application/json", + }, + params={"access_token": self.access_token}, + json=json, + ) + return resp.json() + + def _refresh_access_token_with_lock(self) -> None: + with self._lock: + logger.debug("Refreshing access token") + base_url: str = "https://aip.baidubce.com/oauth/2.0/token" + resp = requests.post( + base_url, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + params={ + "grant_type": "client_credentials", + "client_id": self.ernie_client_id, + "client_secret": self.ernie_client_secret, + }, + ) + self.access_token = str(resp.json().get("access_token")) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + if not self.access_token: + self._refresh_access_token_with_lock() + text_in_chunks = [ + texts[i : i + self.chunk_size] + for i in range(0, len(texts), self.chunk_size) + ] + lst = [] + for chunk in text_in_chunks: + resp = self._embedding({"input": [text for text in chunk]}) + if resp.get("error_code"): + if resp.get("error_code") == 111: + self._refresh_access_token_with_lock() + resp = self._embedding({"input": [text for text in chunk]}) + else: + raise ValueError(f"Error from Ernie: {resp}") + lst.extend([i["embedding"] for i in resp["data"]]) + return lst + + def embed_query(self, text: str) -> List[float]: + if not self.access_token: + self._refresh_access_token_with_lock() + resp = self._embedding({"input": [text]}) + if resp.get("error_code"): + if resp.get("error_code") == 111: + self._refresh_access_token_with_lock() + resp = self._embedding({"input": [text]}) + else: + raise ValueError(f"Error from Ernie: {resp}") + return resp["data"][0]["embedding"] diff --git a/libs/langchain/tests/integration_tests/embeddings/test_ernie.py b/libs/langchain/tests/integration_tests/embeddings/test_ernie.py new file mode 100644 index 00000000000..9f47f1572fd --- /dev/null +++ b/libs/langchain/tests/integration_tests/embeddings/test_ernie.py @@ -0,0 +1,41 @@ +import pytest + +from langchain.embeddings.ernie import ErnieEmbeddings + + +def test_embedding_documents_1() -> None: + documents = ["foo bar"] + embedding = ErnieEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 384 + + +def test_embedding_documents_2() -> None: + documents = ["foo", "bar"] + embedding = ErnieEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 2 + assert len(output[0]) == 384 + assert len(output[1]) == 384 + + +def test_embedding_query() -> None: + query = "foo" + embedding = ErnieEmbeddings() + output = embedding.embed_query(query) + assert len(output) == 384 + + +def test_max_chunks() -> None: + documents = [f"text-{i}" for i in range(20)] + embedding = ErnieEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 20 + + +def test_too_many_chunks() -> None: + documents = [f"text-{i}" for i in range(20)] + embedding = ErnieEmbeddings(chunk_size=20) + with pytest.raises(ValueError): + embedding.embed_documents(documents)