diff --git a/docs/extras/integrations/chat/ernie.ipynb b/docs/extras/integrations/chat/ernie.ipynb
new file mode 100644
index 00000000000..4b1c4c6db83
--- /dev/null
+++ b/docs/extras/integrations/chat/ernie.ipynb
@@ -0,0 +1,87 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# ERNIE-Bot Chat\n",
+    "\n",
+    "This notebook covers how to get started with ERNIE-Bot chat models."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chat_models import ErnieBotChat\n",
+    "from langchain.schema import AIMessage, HumanMessage, SystemMessage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat = ErnieBotChat(ernie_client_id='YOUR_CLIENT_ID', ernie_client_secret='YOUR_CLIENT_SECRET')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Alternatively, you can set the environment variables `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`:\n",
+    "\n",
+    "```bash\n",
+    "export ERNIE_CLIENT_ID=YOUR_CLIENT_ID\n",
+    "export ERNIE_CLIENT_SECRET=YOUR_CLIENT_SECRET\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='Hello, I am an artificial intelligence language model. My purpose is to help users answer questions or provide information. What can I do for you?', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat([\n",
+    "    HumanMessage(content='hello there, who are you?')\n",
+    "])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py
index a115ce816f2..0f26a852b8e 100644
--- a/libs/langchain/langchain/chat_models/__init__.py
+++ b/libs/langchain/langchain/chat_models/__init__.py
@@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.chat_models.anyscale import ChatAnyscale
 from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.google_palm import ChatGooglePalm
 from langchain.chat_models.human import HumanInputChatModel
@@ -43,4 +44,5 @@ __all__ = [
     "HumanInputChatModel",
     "ChatAnyscale",
     "ChatLiteLLM",
+    "ErnieBotChat",
 ]
diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py
new file mode 100644
index 00000000000..c141eaf74dd
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/ernie.py
@@ -0,0 +1,156 @@
+import logging
+import threading
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from pydantic import root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    ChatGeneration,
+    ChatMessage,
+    ChatResult,
+    HumanMessage,
+)
+from langchain.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+    if isinstance(message, ChatMessage):
+        message_dict = {"role": message.role, "content": message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    else:
+        raise ValueError(f"Got unknown message type {message}")
+    return message_dict
+
+
+class ErnieBotChat(BaseChatModel):
+    """ErnieBot Chat large language model.
+
+    ERNIE-Bot is a large language model developed by Baidu,
+    trained on a large corpus of Chinese data.
+
+    To use, set `ernie_client_id` and `ernie_client_secret` on the model,
+    or set the environment variables `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
+
+    The access token is generated automatically from the client id and secret,
+    and is regenerated after it expires.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ErnieBotChat
+            chat = ErnieBotChat()
+
+    """
+
+    ernie_client_id: Optional[str] = None
+    ernie_client_secret: Optional[str] = None
+    access_token: Optional[str] = None
+
+    model_name: str = "ERNIE-Bot-turbo"
+
+    streaming: Optional[bool] = False
+    top_p: Optional[float] = 0.8
+    temperature: Optional[float] = 0.95
+    penalty_score: Optional[float] = 1
+
+    _lock = threading.Lock()
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        values["ernie_client_id"] = get_from_dict_or_env(
+            values,
+            "ernie_client_id",
+            "ERNIE_CLIENT_ID",
+        )
+        values["ernie_client_secret"] = get_from_dict_or_env(
+            values,
+            "ernie_client_secret",
+            "ERNIE_CLIENT_SECRET",
+        )
+        return values
+
+    def _chat(self, payload: Dict) -> dict:
+        base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
+        if self.model_name == "ERNIE-Bot-turbo":
+            url = f"{base_url}/eb-instant"
+        elif self.model_name == "ERNIE-Bot":
+            url = f"{base_url}/completions"
+        else:
+            raise ValueError(f"Got unknown model_name {self.model_name}")
+        resp = requests.post(
+            url,
+            headers={
+                "Content-Type": "application/json",
+            },
+            params={"access_token": self.access_token},
+            json=payload,
+        )
+        return resp.json()
+
+    def _refresh_access_token_with_lock(self) -> None:
+        with self._lock:
+            logger.debug("Refreshing access token")
+            base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
+            resp = requests.post(
+                base_url,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "application/json",
+                },
+                params={
+                    "grant_type": "client_credentials",
+                    "client_id": self.ernie_client_id,
+                    "client_secret": self.ernie_client_secret,
+                },
+            )
+            self.access_token = str(resp.json().get("access_token"))
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            raise ValueError("`streaming` option currently unsupported.")
+
+        if not self.access_token:
+            self._refresh_access_token_with_lock()
+        payload = {
+            "messages": [_convert_message_to_dict(m) for m in messages],
+            "top_p": self.top_p,
+            "temperature": self.temperature,
+            "penalty_score": self.penalty_score,
+        }
+        resp = self._chat(payload)
+        if resp.get("error_code") == 111:
+            # error_code 111 means the access token has expired:
+            # refresh it and retry the request once.
+            self._refresh_access_token_with_lock()
+            resp = self._chat(payload)
+        if resp.get("error_code"):
+            raise ValueError(f"Error from ErnieBot API response: {resp}")
+        return self._create_chat_result(resp)
+
+    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+        generations = [
+            ChatGeneration(message=AIMessage(content=response.get("result")))
+        ]
+        token_usage = response.get("usage", {})
+        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    @property
+    def _llm_type(self) -> str:
+        return "ernie-chat"
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py
new file mode 100644
index 00000000000..cbaba6debbf
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py
@@ -0,0 +1,26 @@
+from langchain.chat_models.ernie import ErnieBotChat
+from langchain.schema.messages import AIMessage, HumanMessage
+
+
+def test_chat_ernie_bot() -> None:
+    chat = ErnieBotChat()
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_ernie_bot_with_model_name() -> None:
+    chat = ErnieBotChat(model_name="ERNIE-Bot")
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_ernie_bot_with_temperature() -> None:
+    chat = ErnieBotChat(model_name="ERNIE-Bot", temperature=1.0)
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_ernie.py b/libs/langchain/tests/unit_tests/chat_models/test_ernie.py
new file mode 100644
index 00000000000..06eb1019615
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/chat_models/test_ernie.py
@@ -0,0 +1,16 @@
+from langchain.chat_models.ernie import _convert_message_to_dict
+from langchain.schema.messages import AIMessage, HumanMessage
+
+
+def test__convert_message_to_dict_human() -> None:
+    message = HumanMessage(content="foo")
+    result = _convert_message_to_dict(message)
+    expected_output = {"role": "user", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_ai() -> None:
+    message = AIMessage(content="foo")
+    result = _convert_message_to_dict(message)
+    expected_output = {"role": "assistant", "content": "foo"}
+    assert result == expected_output
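
Note: the integration tests above hit the live ERNIE-Bot API and need real credentials. The request/response flow in `_generate` (including the expired-token retry) can also be exercised without network access by mocking `requests.post`. Below is a minimal sketch of such unit tests, not part of this PR: the test names and the fake response payloads are illustrative assumptions, with the dictionary shapes mirroring the fields that `_create_chat_result` reads (`result`, `usage`) and the `error_code: 111` expired-token reply handled in `_generate`.

```python
from unittest import mock

from langchain.chat_models.ernie import ErnieBotChat
from langchain.schema.messages import AIMessage, HumanMessage


def test__generate_with_mocked_api() -> None:
    # Credentials are fake; access_token is pre-set so no refresh is needed.
    chat = ErnieBotChat(
        ernie_client_id="fake-id",
        ernie_client_secret="fake-secret",
        access_token="fake-token",
    )
    # Shape mirrors the fields _create_chat_result reads: "result" and "usage".
    fake_json = {
        "result": "Hello!",
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
    }
    with mock.patch("langchain.chat_models.ernie.requests.post") as mock_post:
        mock_post.return_value.json.return_value = fake_json
        result = chat._generate([HumanMessage(content="hi")])

    message = result.generations[0].message
    assert isinstance(message, AIMessage)
    assert message.content == "Hello!"
    assert result.llm_output is not None
    assert result.llm_output["token_usage"]["total_tokens"] == 3


def test__generate_retries_after_expired_token() -> None:
    chat = ErnieBotChat(
        ernie_client_id="fake-id",
        ernie_client_secret="fake-secret",
        access_token="stale-token",
    )
    # Replies, in order: expired-token error, new token, successful retry.
    replies = [
        {"error_code": 111, "error_msg": "access token expired"},
        {"access_token": "fresh-token"},
        {"result": "Hi!", "usage": {}},
    ]
    with mock.patch("langchain.chat_models.ernie.requests.post") as mock_post:
        mock_post.return_value.json.side_effect = replies
        result = chat._generate([HumanMessage(content="hi")])

    assert chat.access_token == "fresh-token"
    assert result.generations[0].message.content == "Hi!"
```

Patching `langchain.chat_models.ernie.requests.post` intercepts both the token-refresh call and the chat call, so the `side_effect` list in the second test plays back the expired-token reply, the new token, and the retried completion in order.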