langchain[minor]: add volcengine endpoint as LLM (#13942)

- **Description:** Volc Engine MaaS serves as an enterprise-grade,
large-model service platform designed for developers. You can visit its
homepage at https://www.volcengine.com/docs/82379/1099455 for details.
This change will facilitate developers to integrate quickly with the
platform.
  - **Issue:** None
  - **Dependencies:** volcengine
  - **Tag maintainer:** @baskaryan 
  - **Twitter handle:** @he1v3tica

---------

Co-authored-by: lvzhong <lvzhong@bytedance.com>
This commit is contained in:
h3l 2023-11-30 05:16:42 +08:00 committed by GitHub
parent 1600ebe6c7
commit dbaeb163aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 729 additions and 0 deletions

View File

@ -0,0 +1,177 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "404758628c7b20f6",
"metadata": {
"collapsed": false
},
"source": [
"# Volc Engine Maas\n",
"\n",
"This notebook provides you with a guide on how to get started with volc engine maas chat models."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2cd2ebd9d023c4d3",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Install the package\n",
"!pip install volcengine"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "51e7f967cb78f5b7",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:43:37.131292Z",
"start_time": "2023-11-27T10:43:37.127250Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"from langchain.chat_models import VolcEngineMaasChat\n",
"from langchain.schema import HumanMessage"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "139667d44689f9e0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:43:49.911867Z",
"start_time": "2023-11-27T10:43:49.908329Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"chat = VolcEngineMaasChat(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")"
]
},
{
"cell_type": "markdown",
"id": "e84ebc4feedcc739",
"metadata": {
"collapsed": false
},
"source": [
"or you can set access_key and secret_key in your environment variables\n",
"```bash\n",
"export VOLC_ACCESSKEY=YOUR_AK\n",
"export VOLC_SECRETKEY=YOUR_SK\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "35da18414ad17aa0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:43:53.101852Z",
"start_time": "2023-11-27T10:43:51.741041Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n为什么鸟儿不会玩电脑游戏\\n\\n因为它们没有翅膀')"
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat([HumanMessage(content=\"给我讲个笑话\")])"
]
},
{
"cell_type": "markdown",
"id": "a55e5a9ed80ec49e",
"metadata": {
"collapsed": false
},
"source": [
"# volc engine maas chat with stream"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "b4e4049980ac68ef",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:43:55.120405Z",
"start_time": "2023-11-27T10:43:55.114707Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"chat = VolcEngineMaasChat(\n",
" volc_engine_maas_ak=\"your ak\",\n",
" volc_engine_maas_sk=\"your sk\",\n",
" streaming=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "fe709a4ffb5c811d",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:43:58.775294Z",
"start_time": "2023-11-27T10:43:56.799401Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n三岁的女儿说她会造句了妈妈让她用“年轻”造句女儿说“妈妈减肥一年轻了好几斤”。')"
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat([HumanMessage(content=\"给我讲个笑话\")])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,124 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "404758628c7b20f6",
"metadata": {
"collapsed": false
},
"source": [
"# Volc Engine Maas\n",
"\n",
"This notebook provides you with a guide on how to get started with Volc Engine's MaaS llm models."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "946db204b33c2ef7",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Install the package\n",
"!pip install volcengine"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "51e7f967cb78f5b7",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:40:26.897649Z",
"start_time": "2023-11-27T10:40:26.552589Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"from langchain.llms import VolcEngineMaasLLM\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "139667d44689f9e0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:40:27.938517Z",
"start_time": "2023-11-27T10:40:27.861324Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"llm = VolcEngineMaasLLM(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")"
]
},
{
"cell_type": "markdown",
"id": "e84ebc4feedcc739",
"metadata": {
"collapsed": false
},
"source": [
"or you can set access_key and secret_key in your environment variables\n",
"```bash\n",
"export VOLC_ACCESSKEY=YOUR_AK\n",
"export VOLC_SECRETKEY=YOUR_SK\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "35da18414ad17aa0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-27T10:41:35.528526Z",
"start_time": "2023-11-27T10:41:32.562238Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": "'好的,下面是一个笑话:\\n\\n大学暑假我配了隐形眼镜回家给爷爷说我现在配了隐形眼镜。\\n爷爷让我给他看看于是我用小镊子夹了一片给爷爷看。\\n爷爷看完便准备出门边走还边说“真高级啊还真是隐形眼镜”\\n等爷爷出去后我才发现我刚没夹起来'"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain = PromptTemplate.from_template(\"给我讲个笑话\") | llm | StrOutputParser()\n",
"chain.invoke({})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -43,6 +43,7 @@ from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.chat_models.vertexai import ChatVertexAI
from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
from langchain.chat_models.yandex import ChatYandexGPT
__all__ = [
@ -73,4 +74,5 @@ __all__ = [
"ChatBaichuan",
"ChatHunyuan",
"GigaChat",
"VolcEngineMaasChat",
]

View File

@ -0,0 +1,141 @@
from __future__ import annotations
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.volcengine_maas import VolcEngineMaasBase
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
content = _dict.get("choice", {}).get("message", {}).get("content", "")
return AIMessage(
content=content,
)
class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
"""volc engine maas hosts a plethora of models.
You can utilize these models through this class.
To use, you should have the ``volcengine`` python package installed.
and set access key and secret key by environment variable or direct pass those
to this class.
access key, secret key are required parameters which you could get help
https://www.volcengine.com/docs/6291/65568
In order to use them, it is necessary to install the 'volcengine' Python package.
The access key and secret key must be set either via environment variables or
passed directly to this class.
access key and secret key are mandatory parameters for which assistance can be
sought at https://www.volcengine.com/docs/6291/65568.
The two methods are as follows:
* Environment Variable
Set the environment variables 'VOLC_ACCESSKEY' and 'VOLC_SECRETKEY' with your
access key and secret key.
* Pass Directly to Class
Example:
.. code-block:: python
from langchain.llms import VolcEngineMaasLLM
model = VolcEngineMaasChat(model="skylark-lite-public",
volc_engine_maas_ak="your_ak",
volc_engine_maas_sk="your_sk")
"""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "volc-engine-maas-chat"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@property
def _identifying_params(self) -> Dict[str, Any]:
return {
**{"endpoint": self.endpoint, "model": self.model},
**super()._identifying_params,
}
def _convert_prompt_msg_params(
self,
messages: List[BaseMessage],
**kwargs: Any,
) -> Dict[str, Any]:
model_req = {
"model": {
"name": self.model,
}
}
if self.model_version is not None:
model_req["model"]["version"] = self.model_version
return {
**model_req,
"messages": [_convert_message_to_dict(message) for message in messages],
"parameters": {**self._default_params, **kwargs},
}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
for res in self.client.stream_chat(params):
if res:
msg = convert_dict_to_message(res)
yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
if run_manager:
run_manager.on_llm_new_token(msg.content)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
completion = ""
if self.streaming:
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
params = self._convert_prompt_msg_params(messages, **kwargs)
res = self.client.chat(params)
msg = convert_dict_to_message(res)
completion = msg.content
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])

View File

@ -504,6 +504,12 @@ def _import_yandex_gpt() -> Any:
return YandexGPT
def _import_volcengine_maas() -> Any:
from langchain.llms.volcengine_maas import VolcEngineMaasLLM
return VolcEngineMaasLLM
def __getattr__(name: str) -> Any:
if name == "AI21":
return _import_ai21()
@ -665,6 +671,8 @@ def __getattr__(name: str) -> Any:
return _import_xinference()
elif name == "YandexGPT":
return _import_yandex_gpt()
elif name == "VolcEngineMaasLLM":
return _import_volcengine_maas()
elif name == "type_to_cls_dict":
# for backwards compatibility
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@ -755,6 +763,7 @@ __all__ = [
"JavelinAIGateway",
"QianfanLLMEndpoint",
"YandexGPT",
"VolcEngineMaasLLM",
]
@ -834,4 +843,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"javelin-ai-gateway": _import_javelin_ai_gateway,
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
"yandex_gpt": _import_yandex_gpt,
"VolcEngineMaasLLM": _import_volcengine_maas(),
}

View File

@ -0,0 +1,176 @@
from __future__ import annotations
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
class VolcEngineMaasBase(BaseModel):
"""Base class for VolcEngineMaas models."""
client: Any
volc_engine_maas_ak: Optional[str] = None
"""access key for volc engine"""
volc_engine_maas_sk: Optional[str] = None
"""secret key for volc engine"""
endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com"
"""Endpoint of the VolcEngineMaas LLM."""
region: Optional[str] = "Region"
"""Region of the VolcEngineMaas LLM."""
model: str = "skylark-lite-public"
"""Model name. you could check this model details here
https://www.volcengine.com/docs/82379/1133187
and you could choose other models by change this field"""
model_version: Optional[str] = None
"""Model version. Only used in moonshot large language model.
you could check details here https://www.volcengine.com/docs/82379/1158281"""
top_p: Optional[float] = 0.8
"""Total probability mass of tokens to consider at each step."""
temperature: Optional[float] = 0.95
"""A non-negative float that tunes the degree of randomness in generation."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""model special arguments, you could check detail on model page"""
streaming: bool = False
"""Whether to stream the results."""
connect_timeout: Optional[int] = 60
"""Timeout for connect to volc engine maas endpoint. Default is 60 seconds."""
read_timeout: Optional[int] = 60
"""Timeout for read response from volc engine maas endpoint.
Default is 60 seconds."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
ak = get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY")
sk = get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY")
endpoint = values["endpoint"]
if values["endpoint"] is not None and values["endpoint"] != "":
endpoint = values["endpoint"]
try:
from volcengine.maas import MaasService
maas = MaasService(
endpoint,
values["region"],
connection_timeout=values["connect_timeout"],
socket_timeout=values["read_timeout"],
)
maas.set_ak(ak)
values["volc_engine_maas_ak"] = ak
values["volc_engine_maas_sk"] = sk
maas.set_sk(sk)
values["client"] = maas
except ImportError:
raise ImportError(
"volcengine package not found, please install it with "
"`pip install volcengine`"
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling VolcEngineMaas API."""
normal_params = {
"top_p": self.top_p,
"temperature": self.temperature,
}
return {**normal_params, **self.model_kwargs}
class VolcEngineMaasLLM(LLM, VolcEngineMaasBase):
"""volc engine maas hosts a plethora of models.
You can utilize these models through this class.
To use, you should have the ``volcengine`` python package installed.
and set access key and secret key by environment variable or direct pass those to
this class.
access key, secret key are required parameters which you could get help
https://www.volcengine.com/docs/6291/65568
In order to use them, it is necessary to install the 'volcengine' Python package.
The access key and secret key must be set either via environment variables or
passed directly to this class.
access key and secret key are mandatory parameters for which assistance can be
sought at https://www.volcengine.com/docs/6291/65568.
Example:
.. code-block:: python
from langchain.llms import VolcEngineMaasLLM
model = VolcEngineMaasLLM(model="skylark-lite-public",
volc_engine_maas_ak="your_ak",
volc_engine_maas_sk="your_sk")
"""
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "volc-engine-maas-llm"
def _convert_prompt_msg_params(
self,
prompt: str,
**kwargs: Any,
) -> dict:
model_req = {
"model": {
"name": self.model,
}
}
if self.model_version is not None:
model_req["model"]["version"] = self.model_version
return {
**model_req,
"messages": [{"role": "user", "content": prompt}],
"parameters": {**self._default_params, **kwargs},
}
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.streaming:
completion = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
params = self._convert_prompt_msg_params(prompt, **kwargs)
response = self.client.chat(params)
return response.get("choice", {}).get("message", {}).get("content", "")
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._convert_prompt_msg_params(prompt, **kwargs)
for res in self.client.stream_chat(params):
if res:
chunk = GenerationChunk(
text=res.get("choice", {}).get("message", {}).get("content", "")
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)

View File

@ -0,0 +1,69 @@
"""Test volc engine maas chat model."""
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_default_call() -> None:
"""Test valid chat call to volc engine."""
chat = VolcEngineMaasChat()
response = chat(messages=[HumanMessage(content="Hello")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_history() -> None:
"""Tests multiple history works."""
chat = VolcEngineMaasChat()
response = chat(
messages=[
HumanMessage(content="Hello"),
AIMessage(content="Hello!"),
HumanMessage(content="How are you?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_stream() -> None:
"""Test that stream works."""
chat = VolcEngineMaasChat(streaming=True)
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
response = chat(
messages=[
HumanMessage(content="Hello"),
AIMessage(content="Hello!"),
HumanMessage(content="How are you?"),
],
stream=True,
callbacks=callback_manager,
)
assert callback_handler.llm_streams > 0
assert isinstance(response.content, str)
def test_multiple_messages() -> None:
"""Tests multiple messages works."""
chat = VolcEngineMaasChat()
message = HumanMessage(content="Hi, how are you?")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content

View File

@ -0,0 +1,28 @@
"""Test volc engine maas LLM model."""
from typing import Generator
from langchain.llms.volcengine_maas import VolcEngineMaasLLM
from langchain.schema import LLMResult
def test_default_call() -> None:
"""Test valid call to volc engine."""
llm = VolcEngineMaasLLM()
output = llm("tell me a joke")
assert isinstance(output, str)
def test_generate() -> None:
"""Test valid call to volc engine."""
llm = VolcEngineMaasLLM()
output = llm.generate(["tell me a joke"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
def test_generate_stream() -> None:
"""Test valid call to volc engine."""
llm = VolcEngineMaasLLM(streaming=True)
output = llm.stream("tell me a joke")
assert isinstance(output, Generator)

View File

@ -28,6 +28,7 @@ EXPECTED_ALL = [
"ChatBaichuan",
"ChatHunyuan",
"GigaChat",
"VolcEngineMaasChat",
]

View File

@ -81,6 +81,7 @@ EXPECT_ALL = [
"JavelinAIGateway",
"QianfanLLMEndpoint",
"YandexGPT",
"VolcEngineMaasLLM",
]