From dbaeb163aaddc17bc76cbe1b34fb76207d329ba2 Mon Sep 17 00:00:00 2001 From: h3l Date: Thu, 30 Nov 2023 05:16:42 +0800 Subject: [PATCH] langchain[minor]: add volcengine endpoint as LLM (#13942) - **Description:** Volc Engine MaaS serves as an enterprise-grade, large-model service platform designed for developers. You can visit its homepage at https://www.volcengine.com/docs/82379/1099455 for details. This change will facilitate developers to integrate quickly with the platform. - **Issue:** None - **Dependencies:** volcengine - **Tag maintainer:** @baskaryan - **Twitter handle:** @he1v3tica --------- Co-authored-by: lvzhong --- .../integrations/chat/volcengine_maas.ipynb | 177 ++++++++++++++++++ .../integrations/llms/volcengine_maas.ipynb | 124 ++++++++++++ .../langchain/chat_models/__init__.py | 2 + .../langchain/chat_models/volcengine_maas.py | 141 ++++++++++++++ libs/langchain/langchain/llms/__init__.py | 10 + .../langchain/llms/volcengine_maas.py | 176 +++++++++++++++++ .../chat_models/test_volcengine_maas.py | 69 +++++++ .../llms/test_volcengine_maas.py | 28 +++ .../unit_tests/chat_models/test_imports.py | 1 + .../tests/unit_tests/llms/test_imports.py | 1 + 10 files changed, 729 insertions(+) create mode 100644 docs/docs/integrations/chat/volcengine_maas.ipynb create mode 100644 docs/docs/integrations/llms/volcengine_maas.ipynb create mode 100644 libs/langchain/langchain/chat_models/volcengine_maas.py create mode 100644 libs/langchain/langchain/llms/volcengine_maas.py create mode 100644 libs/langchain/tests/integration_tests/chat_models/test_volcengine_maas.py create mode 100644 libs/langchain/tests/integration_tests/llms/test_volcengine_maas.py diff --git a/docs/docs/integrations/chat/volcengine_maas.ipynb b/docs/docs/integrations/chat/volcengine_maas.ipynb new file mode 100644 index 00000000000..32dd0c16d09 --- /dev/null +++ b/docs/docs/integrations/chat/volcengine_maas.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "404758628c7b20f6", + "metadata": { + "collapsed": false + }, + "source": [ + "# Volc Engine Maas\n", + "\n", + "This notebook provides you with a guide on how to get started with volc engine maas chat models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cd2ebd9d023c4d3", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Install the package\n", + "!pip install volcengine" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "51e7f967cb78f5b7", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:43:37.131292Z", + "start_time": "2023-11-27T10:43:37.127250Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain.chat_models import VolcEngineMaasChat\n", + "from langchain.schema import HumanMessage" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "139667d44689f9e0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:43:49.911867Z", + "start_time": "2023-11-27T10:43:49.908329Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "chat = VolcEngineMaasChat(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")" + ] + }, + { + "cell_type": "markdown", + "id": "e84ebc4feedcc739", + "metadata": { + "collapsed": false + }, + "source": [ + "or you can set access_key and secret_key in your environment variables\n", + "```bash\n", + "export VOLC_ACCESSKEY=YOUR_AK\n", + "export VOLC_SECRETKEY=YOUR_SK\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "35da18414ad17aa0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:43:53.101852Z", + "start_time": "2023-11-27T10:43:51.741041Z" + }, + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n为什么鸟儿不会玩电脑游戏?\\n\\n因为它们没有翅膀!')" + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat([HumanMessage(content=\"给我讲个笑话\")])" + ] + }, + { + "cell_type": "markdown", + "id": "a55e5a9ed80ec49e", + "metadata": { + "collapsed": false + }, + "source": [ + "# volc engine maas chat with stream" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "b4e4049980ac68ef", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:43:55.120405Z", + "start_time": "2023-11-27T10:43:55.114707Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "chat = VolcEngineMaasChat(\n", + " volc_engine_maas_ak=\"your ak\",\n", + " volc_engine_maas_sk=\"your sk\",\n", + " streaming=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "fe709a4ffb5c811d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:43:58.775294Z", + "start_time": "2023-11-27T10:43:56.799401Z" + }, + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n三岁的女儿说她会造句了,妈妈让她用“年轻”造句,女儿说:“妈妈减肥,一年轻了好几斤”。')" + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat([HumanMessage(content=\"给我讲个笑话\")])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/integrations/llms/volcengine_maas.ipynb b/docs/docs/integrations/llms/volcengine_maas.ipynb new file mode 100644 index 00000000000..c34cde60c0b --- /dev/null +++ b/docs/docs/integrations/llms/volcengine_maas.ipynb @@ -0,0 +1,124 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "404758628c7b20f6", + "metadata": { + "collapsed": false + }, + "source": [ + "# Volc Engine Maas\n", + "\n", + "This notebook provides you with a guide on how to get started with Volc Engine's MaaS llm models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "946db204b33c2ef7", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Install the package\n", + "!pip install volcengine" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "51e7f967cb78f5b7", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:40:26.897649Z", + "start_time": "2023-11-27T10:40:26.552589Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "from langchain.llms import VolcEngineMaasLLM\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.schema.output_parser import StrOutputParser" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "139667d44689f9e0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:40:27.938517Z", + "start_time": "2023-11-27T10:40:27.861324Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "llm = VolcEngineMaasLLM(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")" + ] + }, + { + "cell_type": "markdown", + "id": "e84ebc4feedcc739", + "metadata": { + "collapsed": false + }, + "source": [ + "or you can set access_key and secret_key in your environment variables\n", + "```bash\n", + "export VOLC_ACCESSKEY=YOUR_AK\n", + "export VOLC_SECRETKEY=YOUR_SK\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "35da18414ad17aa0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-27T10:41:35.528526Z", + "start_time": "2023-11-27T10:41:32.562238Z" + }, + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": "'好的,下面是一个笑话:\\n\\n大学暑假我配了隐形眼镜,回家给爷爷说,我现在配了隐形眼镜。\\n爷爷让我给他看看,于是,我用小镊子夹了一片给爷爷看。\\n爷爷看完便准备出门,边走还边说:“真高级啊,还真是隐形眼镜!”\\n等爷爷出去后我才发现,我刚没夹起来!'" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain = PromptTemplate.from_template(\"给我讲个笑话\") | llm | StrOutputParser()\n", + "chain.invoke({})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index 2471fe1c3f2..b2db80b470f 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -43,6 +43,7 @@ from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.chat_models.vertexai import ChatVertexAI +from langchain.chat_models.volcengine_maas import VolcEngineMaasChat from langchain.chat_models.yandex import ChatYandexGPT __all__ = [ @@ -73,4 +74,5 @@ __all__ = [ "ChatBaichuan", "ChatHunyuan", "GigaChat", + "VolcEngineMaasChat", ] diff --git a/libs/langchain/langchain/chat_models/volcengine_maas.py b/libs/langchain/langchain/chat_models/volcengine_maas.py new file mode 100644 index 00000000000..586febe38c5 --- /dev/null +++ b/libs/langchain/langchain/chat_models/volcengine_maas.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import Any, Dict, Iterator, List, Mapping, Optional + +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import BaseChatModel +from langchain.llms.volcengine_maas import VolcEngineMaasBase + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = {"role": "function", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + return message_dict + + +def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage: + content = _dict.get("choice", {}).get("message", {}).get("content", "") + return AIMessage( + content=content, + ) + + +class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase): + + """volc engine maas hosts a plethora of models. + You can utilize these models through this class. + + To use, you should have the ``volcengine`` python package installed. + and set access key and secret key by environment variable or direct pass those + to this class. + access key, secret key are required parameters which you could get help + https://www.volcengine.com/docs/6291/65568 + + In order to use them, it is necessary to install the 'volcengine' Python package. + The access key and secret key must be set either via environment variables or + passed directly to this class. + access key and secret key are mandatory parameters for which assistance can be + sought at https://www.volcengine.com/docs/6291/65568. + + The two methods are as follows: + * Environment Variable + Set the environment variables 'VOLC_ACCESSKEY' and 'VOLC_SECRETKEY' with your + access key and secret key. + + * Pass Directly to Class + Example: + .. code-block:: python + + from langchain.llms import VolcEngineMaasLLM + model = VolcEngineMaasChat(model="skylark-lite-public", + volc_engine_maas_ak="your_ak", + volc_engine_maas_sk="your_sk") + """ + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "volc-engine-maas-chat" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this model can be serialized by Langchain.""" + return True + + @property + def _identifying_params(self) -> Dict[str, Any]: + return { + **{"endpoint": self.endpoint, "model": self.model}, + **super()._identifying_params, + } + + def _convert_prompt_msg_params( + self, + messages: List[BaseMessage], + **kwargs: Any, + ) -> Dict[str, Any]: + model_req = { + "model": { + "name": self.model, + } + } + if self.model_version is not None: + model_req["model"]["version"] = self.model_version + return { + **model_req, + "messages": [_convert_message_to_dict(message) for message in messages], + "parameters": {**self._default_params, **kwargs}, + } + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + params = self._convert_prompt_msg_params(messages, **kwargs) + for res in self.client.stream_chat(params): + if res: + msg = convert_dict_to_message(res) + yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content)) + if run_manager: + run_manager.on_llm_new_token(msg.content) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + completion = "" + if self.streaming: + for chunk in self._stream(messages, stop, run_manager, **kwargs): + completion += chunk.text + else: + params = self._convert_prompt_msg_params(messages, **kwargs) + res = self.client.chat(params) + msg = convert_dict_to_message(res) + completion = msg.content + + message = AIMessage(content=completion) + return ChatResult(generations=[ChatGeneration(message=message)]) diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py index 351a75f3786..407849ee624 100644 --- a/libs/langchain/langchain/llms/__init__.py +++ b/libs/langchain/langchain/llms/__init__.py @@ -504,6 +504,12 @@ def _import_yandex_gpt() -> Any: return YandexGPT +def _import_volcengine_maas() -> Any: + from langchain.llms.volcengine_maas import VolcEngineMaasLLM + + return VolcEngineMaasLLM + + def __getattr__(name: str) -> Any: if name == "AI21": return _import_ai21() @@ -665,6 +671,8 @@ def __getattr__(name: str) -> Any: return _import_xinference() elif name == "YandexGPT": return _import_yandex_gpt() + elif name == "VolcEngineMaasLLM": + return _import_volcengine_maas() elif name == "type_to_cls_dict": # for backwards compatibility type_to_cls_dict: Dict[str, Type[BaseLLM]] = { @@ -755,6 +763,7 @@ __all__ = [ "JavelinAIGateway", "QianfanLLMEndpoint", "YandexGPT", + "VolcEngineMaasLLM", ] @@ -834,4 +843,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "javelin-ai-gateway": _import_javelin_ai_gateway, "qianfan_endpoint": _import_baidu_qianfan_endpoint, "yandex_gpt": _import_yandex_gpt, + "VolcEngineMaasLLM": _import_volcengine_maas(), } diff --git a/libs/langchain/langchain/llms/volcengine_maas.py b/libs/langchain/langchain/llms/volcengine_maas.py new file mode 100644 index 00000000000..9dabb20367b --- /dev/null +++ b/libs/langchain/langchain/llms/volcengine_maas.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from typing import Any, Dict, Iterator, List, Optional + +from langchain_core.outputs import GenerationChunk +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +from langchain.utils import get_from_dict_or_env + + +class VolcEngineMaasBase(BaseModel): + """Base class for VolcEngineMaas models.""" + + client: Any + + volc_engine_maas_ak: Optional[str] = None + """access key for volc engine""" + volc_engine_maas_sk: Optional[str] = None + """secret key for volc engine""" + + endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com" + """Endpoint of the VolcEngineMaas LLM.""" + + region: Optional[str] = "Region" + """Region of the VolcEngineMaas LLM.""" + + model: str = "skylark-lite-public" + """Model name. you could check this model details here + https://www.volcengine.com/docs/82379/1133187 + and you could choose other models by change this field""" + model_version: Optional[str] = None + """Model version. Only used in moonshot large language model. + you could check details here https://www.volcengine.com/docs/82379/1158281""" + + top_p: Optional[float] = 0.8 + """Total probability mass of tokens to consider at each step.""" + + temperature: Optional[float] = 0.95 + """A non-negative float that tunes the degree of randomness in generation.""" + + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """model special arguments, you could check detail on model page""" + + streaming: bool = False + """Whether to stream the results.""" + + connect_timeout: Optional[int] = 60 + """Timeout for connect to volc engine maas endpoint. Default is 60 seconds.""" + + read_timeout: Optional[int] = 60 + """Timeout for read response from volc engine maas endpoint. + Default is 60 seconds.""" + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + ak = get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY") + sk = get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY") + endpoint = values["endpoint"] + if values["endpoint"] is not None and values["endpoint"] != "": + endpoint = values["endpoint"] + try: + from volcengine.maas import MaasService + + maas = MaasService( + endpoint, + values["region"], + connection_timeout=values["connect_timeout"], + socket_timeout=values["read_timeout"], + ) + maas.set_ak(ak) + values["volc_engine_maas_ak"] = ak + values["volc_engine_maas_sk"] = sk + maas.set_sk(sk) + values["client"] = maas + except ImportError: + raise ImportError( + "volcengine package not found, please install it with " + "`pip install volcengine`" + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling VolcEngineMaas API.""" + normal_params = { + "top_p": self.top_p, + "temperature": self.temperature, + } + + return {**normal_params, **self.model_kwargs} + + +class VolcEngineMaasLLM(LLM, VolcEngineMaasBase): + """volc engine maas hosts a plethora of models. + You can utilize these models through this class. + + To use, you should have the ``volcengine`` python package installed. + and set access key and secret key by environment variable or direct pass those to + this class. + access key, secret key are required parameters which you could get help + https://www.volcengine.com/docs/6291/65568 + + In order to use them, it is necessary to install the 'volcengine' Python package. + The access key and secret key must be set either via environment variables or + passed directly to this class. + access key and secret key are mandatory parameters for which assistance can be + sought at https://www.volcengine.com/docs/6291/65568. + + Example: + .. code-block:: python + + from langchain.llms import VolcEngineMaasLLM + model = VolcEngineMaasLLM(model="skylark-lite-public", + volc_engine_maas_ak="your_ak", + volc_engine_maas_sk="your_sk") + """ + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "volc-engine-maas-llm" + + def _convert_prompt_msg_params( + self, + prompt: str, + **kwargs: Any, + ) -> dict: + model_req = { + "model": { + "name": self.model, + } + } + if self.model_version is not None: + model_req["model"]["version"] = self.model_version + + return { + **model_req, + "messages": [{"role": "user", "content": prompt}], + "parameters": {**self._default_params, **kwargs}, + } + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + if self.streaming: + completion = "" + for chunk in self._stream(prompt, stop, run_manager, **kwargs): + completion += chunk.text + return completion + params = self._convert_prompt_msg_params(prompt, **kwargs) + response = self.client.chat(params) + + return response.get("choice", {}).get("message", {}).get("content", "") + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + params = self._convert_prompt_msg_params(prompt, **kwargs) + for res in self.client.stream_chat(params): + if res: + chunk = GenerationChunk( + text=res.get("choice", {}).get("message", {}).get("content", "") + ) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) diff --git a/libs/langchain/tests/integration_tests/chat_models/test_volcengine_maas.py b/libs/langchain/tests/integration_tests/chat_models/test_volcengine_maas.py new file mode 100644 index 00000000000..31800b9a4e7 --- /dev/null +++ b/libs/langchain/tests/integration_tests/chat_models/test_volcengine_maas.py @@ -0,0 +1,69 @@ +"""Test volc engine maas chat model.""" + +from langchain.callbacks.manager import CallbackManager +from langchain.chat_models.volcengine_maas import VolcEngineMaasChat +from langchain.schema import ( + AIMessage, + BaseMessage, + ChatGeneration, + HumanMessage, + LLMResult, +) +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + + +def test_default_call() -> None: + """Test valid chat call to volc engine.""" + chat = VolcEngineMaasChat() + response = chat(messages=[HumanMessage(content="Hello")]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_multiple_history() -> None: + """Tests multiple history works.""" + chat = VolcEngineMaasChat() + + response = chat( + messages=[ + HumanMessage(content="Hello"), + AIMessage(content="Hello!"), + HumanMessage(content="How are you?"), + ] + ) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_stream() -> None: + """Test that stream works.""" + chat = VolcEngineMaasChat(streaming=True) + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + response = chat( + messages=[ + HumanMessage(content="Hello"), + AIMessage(content="Hello!"), + HumanMessage(content="How are you?"), + ], + stream=True, + callbacks=callback_manager, + ) + assert callback_handler.llm_streams > 0 + assert isinstance(response.content, str) + + +def test_multiple_messages() -> None: + """Tests multiple messages works.""" + chat = VolcEngineMaasChat() + message = HumanMessage(content="Hi, how are you?") + response = chat.generate([[message], [message]]) + + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in response.generations: + assert len(generations) == 1 + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content diff --git a/libs/langchain/tests/integration_tests/llms/test_volcengine_maas.py b/libs/langchain/tests/integration_tests/llms/test_volcengine_maas.py new file mode 100644 index 00000000000..22a357034d4 --- /dev/null +++ b/libs/langchain/tests/integration_tests/llms/test_volcengine_maas.py @@ -0,0 +1,28 @@ +"""Test volc engine maas LLM model.""" + +from typing import Generator + +from langchain.llms.volcengine_maas import VolcEngineMaasLLM +from langchain.schema import LLMResult + + +def test_default_call() -> None: + """Test valid call to volc engine.""" + llm = VolcEngineMaasLLM() + output = llm("tell me a joke") + assert isinstance(output, str) + + +def test_generate() -> None: + """Test valid call to volc engine.""" + llm = VolcEngineMaasLLM() + output = llm.generate(["tell me a joke"]) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list) + + +def test_generate_stream() -> None: + """Test valid call to volc engine.""" + llm = VolcEngineMaasLLM(streaming=True) + output = llm.stream("tell me a joke") + assert isinstance(output, Generator) diff --git a/libs/langchain/tests/unit_tests/chat_models/test_imports.py b/libs/langchain/tests/unit_tests/chat_models/test_imports.py index c78abe603c9..2adb8ec65a5 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_imports.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_imports.py @@ -28,6 +28,7 @@ EXPECTED_ALL = [ "ChatBaichuan", "ChatHunyuan", "GigaChat", + "VolcEngineMaasChat", ] diff --git a/libs/langchain/tests/unit_tests/llms/test_imports.py b/libs/langchain/tests/unit_tests/llms/test_imports.py index 1ba5c6cb01d..6ab0e1f036f 100644 --- a/libs/langchain/tests/unit_tests/llms/test_imports.py +++ b/libs/langchain/tests/unit_tests/llms/test_imports.py @@ -81,6 +81,7 @@ EXPECT_ALL = [ "JavelinAIGateway", "QianfanLLMEndpoint", "YandexGPT", + "VolcEngineMaasLLM", ]