mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-15 15:46:47 +00:00)
langchain[minor]: add volcengine endpoint as LLM (#13942)
- **Description:** Volc Engine MaaS is an enterprise-grade large-model service platform for developers; see https://www.volcengine.com/docs/82379/1099455 for details. This change lets developers integrate with the platform quickly.
- **Issue:** None
- **Dependencies:** volcengine
- **Tag maintainer:** @baskaryan
- **Twitter handle:** @he1v3tica

Co-authored-by: lvzhong <lvzhong@bytedance.com>
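At a glance, the change adds two entry points, `VolcEngineMaasLLM` and `VolcEngineMaasChat`. A minimal sketch of their use, drawn from the notebooks below (`"your ak"`/`"your sk"` are placeholder credentials; `pip install volcengine` is assumed):

```python
from langchain.chat_models import VolcEngineMaasChat
from langchain.llms import VolcEngineMaasLLM
from langchain.schema import HumanMessage

# Completion-style LLM interface
llm = VolcEngineMaasLLM(volc_engine_maas_ak="your ak", volc_engine_maas_sk="your sk")
print(llm("tell me a joke"))

# Chat interface over the same service
chat = VolcEngineMaasChat(volc_engine_maas_ak="your ak", volc_engine_maas_sk="your sk")
print(chat([HumanMessage(content="hello")]))
```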
This commit is contained in: parent 1600ebe6c7, commit dbaeb163aa
177 docs/docs/integrations/chat/volcengine_maas.ipynb (new file)
@@ -0,0 +1,177 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "404758628c7b20f6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Volc Engine MaaS\n",
    "\n",
    "This notebook shows you how to get started with Volc Engine MaaS chat models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd2ebd9d023c4d3",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Install the package\n",
    "!pip install volcengine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "51e7f967cb78f5b7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:43:37.131292Z",
     "start_time": "2023-11-27T10:43:37.127250Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import VolcEngineMaasChat\n",
    "from langchain.schema import HumanMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "139667d44689f9e0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:43:49.911867Z",
     "start_time": "2023-11-27T10:43:49.908329Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "chat = VolcEngineMaasChat(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e84ebc4feedcc739",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Alternatively, you can set the access key and secret key via environment variables:\n",
    "```bash\n",
    "export VOLC_ACCESSKEY=YOUR_AK\n",
    "export VOLC_SECRETKEY=YOUR_SK\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "35da18414ad17aa0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:43:53.101852Z",
     "start_time": "2023-11-27T10:43:51.741041Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n为什么鸟儿不会玩电脑游戏?\\n\\n因为它们没有翅膀!')"
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat([HumanMessage(content=\"给我讲个笑话\")])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a55e5a9ed80ec49e",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Volc Engine MaaS chat with streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b4e4049980ac68ef",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:43:55.120405Z",
     "start_time": "2023-11-27T10:43:55.114707Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "chat = VolcEngineMaasChat(\n",
    "    volc_engine_maas_ak=\"your ak\",\n",
    "    volc_engine_maas_sk=\"your sk\",\n",
    "    streaming=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "fe709a4ffb5c811d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:43:58.775294Z",
     "start_time": "2023-11-27T10:43:56.799401Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n三岁的女儿说她会造句了,妈妈让她用“年轻”造句,女儿说:“妈妈减肥,一年轻了好几斤”。')"
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat([HumanMessage(content=\"给我讲个笑话\")])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
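When the two variables from the notebook's environment-variable cell are exported, the constructor needs no key arguments; a minimal sketch (the values below are placeholders for illustration only):

```python
import os

# Placeholders; in practice export VOLC_ACCESSKEY / VOLC_SECRETKEY in your shell.
os.environ["VOLC_ACCESSKEY"] = "your_ak"
os.environ["VOLC_SECRETKEY"] = "your_sk"

from langchain.chat_models import VolcEngineMaasChat

# Credentials are resolved from the environment by the model's root validator.
chat = VolcEngineMaasChat()
```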
124 docs/docs/integrations/llms/volcengine_maas.ipynb (new file)
@@ -0,0 +1,124 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "404758628c7b20f6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Volc Engine MaaS\n",
    "\n",
    "This notebook shows you how to get started with Volc Engine MaaS LLMs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "946db204b33c2ef7",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Install the package\n",
    "!pip install volcengine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "51e7f967cb78f5b7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:40:26.897649Z",
     "start_time": "2023-11-27T10:40:26.552589Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from langchain.llms import VolcEngineMaasLLM\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.schema.output_parser import StrOutputParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "139667d44689f9e0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:40:27.938517Z",
     "start_time": "2023-11-27T10:40:27.861324Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "llm = VolcEngineMaasLLM(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e84ebc4feedcc739",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Alternatively, you can set the access key and secret key via environment variables:\n",
    "```bash\n",
    "export VOLC_ACCESSKEY=YOUR_AK\n",
    "export VOLC_SECRETKEY=YOUR_SK\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "35da18414ad17aa0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-27T10:41:35.528526Z",
     "start_time": "2023-11-27T10:41:32.562238Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": "'好的,下面是一个笑话:\\n\\n大学暑假我配了隐形眼镜,回家给爷爷说,我现在配了隐形眼镜。\\n爷爷让我给他看看,于是,我用小镊子夹了一片给爷爷看。\\n爷爷看完便准备出门,边走还边说:“真高级啊,还真是隐形眼镜!”\\n等爷爷出去后我才发现,我刚没夹起来!'"
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain = PromptTemplate.from_template(\"给我讲个笑话\") | llm | StrOutputParser()\n",
    "chain.invoke({})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
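The LLM notebook does not demonstrate streaming, but the class supports it via the `streaming` flag (exercised in the tests further below); a minimal sketch, assuming valid credentials:

```python
from langchain.llms import VolcEngineMaasLLM

llm = VolcEngineMaasLLM(
    volc_engine_maas_ak="your ak",  # placeholder
    volc_engine_maas_sk="your sk",  # placeholder
    streaming=True,
)

# stream() yields the generated text incrementally
for chunk in llm.stream("tell me a joke"):
    print(chunk, end="", flush=True)
```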
libs/langchain/langchain/chat_models/__init__.py
@@ -43,6 +43,7 @@ from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 from langchain.chat_models.vertexai import ChatVertexAI
+from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
 from langchain.chat_models.yandex import ChatYandexGPT

 __all__ = [

@@ -73,4 +74,5 @@ __all__ = [
     "ChatBaichuan",
     "ChatHunyuan",
     "GigaChat",
+    "VolcEngineMaasChat",
 ]
141 libs/langchain/langchain/chat_models/volcengine_maas.py (new file)
@@ -0,0 +1,141 @@
from __future__ import annotations

from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.volcengine_maas import VolcEngineMaasBase


def _convert_message_to_dict(message: BaseMessage) -> dict:
    if isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {"role": "function", "content": message.content}
    else:
        raise ValueError(f"Got unknown message type: {message}")
    return message_dict


def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
    content = _dict.get("choice", {}).get("message", {}).get("content", "")
    return AIMessage(content=content)


class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
    """Volc Engine MaaS hosts a plethora of models, which you can use
    through this class.

    To use, you must have the ``volcengine`` Python package installed and
    set the access key and secret key either via environment variables or
    by passing them directly to this class. Both keys are required; see
    https://www.volcengine.com/docs/6291/65568 for help obtaining them.

    The two methods are as follows:

    * Environment variables
        Set ``VOLC_ACCESSKEY`` and ``VOLC_SECRETKEY`` to your access key
        and secret key.

    * Pass directly to the class
        Example:
        .. code-block:: python

            from langchain.chat_models import VolcEngineMaasChat

            model = VolcEngineMaasChat(model="skylark-lite-public",
                                       volc_engine_maas_ak="your_ak",
                                       volc_engine_maas_sk="your_sk")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "volc-engine-maas-chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"endpoint": self.endpoint, "model": self.model},
            **super()._identifying_params,
        }

    def _convert_prompt_msg_params(
        self,
        messages: List[BaseMessage],
        **kwargs: Any,
    ) -> Dict[str, Any]:
        model_req = {
            "model": {
                "name": self.model,
            }
        }
        if self.model_version is not None:
            model_req["model"]["version"] = self.model_version
        return {
            **model_req,
            "messages": [_convert_message_to_dict(message) for message in messages],
            "parameters": {**self._default_params, **kwargs},
        }

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._convert_prompt_msg_params(messages, **kwargs)
        for res in self.client.stream_chat(params):
            if res:
                msg = convert_dict_to_message(res)
                yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
                if run_manager:
                    run_manager.on_llm_new_token(msg.content)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""
        if self.streaming:
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            params = self._convert_prompt_msg_params(messages, **kwargs)
            res = self.client.chat(params)
            msg = convert_dict_to_message(res)
            completion = msg.content

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])
libs/langchain/langchain/llms/__init__.py
@@ -504,6 +504,12 @@ def _import_yandex_gpt() -> Any:
     return YandexGPT


+def _import_volcengine_maas() -> Any:
+    from langchain.llms.volcengine_maas import VolcEngineMaasLLM
+
+    return VolcEngineMaasLLM
+
+
 def __getattr__(name: str) -> Any:
     if name == "AI21":
         return _import_ai21()

@@ -665,6 +671,8 @@ def __getattr__(name: str) -> Any:
         return _import_xinference()
     elif name == "YandexGPT":
         return _import_yandex_gpt()
+    elif name == "VolcEngineMaasLLM":
+        return _import_volcengine_maas()
     elif name == "type_to_cls_dict":
         # for backwards compatibility
         type_to_cls_dict: Dict[str, Type[BaseLLM]] = {

@@ -755,6 +763,7 @@ __all__ = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "VolcEngineMaasLLM",
 ]

@@ -834,4 +843,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "javelin-ai-gateway": _import_javelin_ai_gateway,
         "qianfan_endpoint": _import_baidu_qianfan_endpoint,
         "yandex_gpt": _import_yandex_gpt,
+        "VolcEngineMaasLLM": _import_volcengine_maas,
     }
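Note that the registry values are zero-argument callables, so the `volcengine` dependency is only imported when an entry is actually used; a sketch of how the mapping can be consumed:

```python
from langchain.llms import get_type_to_cls_dict

# Look up the loader and call it to trigger the lazy import.
loader = get_type_to_cls_dict()["VolcEngineMaasLLM"]
VolcEngineMaasLLM = loader()
```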
176 libs/langchain/langchain/llms/volcengine_maas.py (new file)
@@ -0,0 +1,176 @@
from __future__ import annotations

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env


class VolcEngineMaasBase(BaseModel):
    """Base class for VolcEngineMaas models."""

    client: Any

    volc_engine_maas_ak: Optional[str] = None
    """Access key for Volc Engine."""
    volc_engine_maas_sk: Optional[str] = None
    """Secret key for Volc Engine."""

    endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com"
    """Endpoint of the VolcEngineMaas LLM."""

    region: Optional[str] = "Region"
    """Region of the VolcEngineMaas LLM."""

    model: str = "skylark-lite-public"
    """Model name. Model details are listed at
    https://www.volcengine.com/docs/82379/1133187;
    choose another model by changing this field."""
    model_version: Optional[str] = None
    """Model version. Only used by the Moonshot large language model;
    see https://www.volcengine.com/docs/82379/1158281 for details."""

    top_p: Optional[float] = 0.8
    """Total probability mass of tokens to consider at each step."""

    temperature: Optional[float] = 0.95
    """A non-negative float that tunes the degree of randomness in generation."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Model-specific arguments; see the model's page for details."""

    streaming: bool = False
    """Whether to stream the results."""

    connect_timeout: Optional[int] = 60
    """Timeout for connecting to the Volc Engine MaaS endpoint.
    Default is 60 seconds."""

    read_timeout: Optional[int] = 60
    """Timeout for reading the response from the Volc Engine MaaS endpoint.
    Default is 60 seconds."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        ak = get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY")
        sk = get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY")
        endpoint = values["endpoint"]
        try:
            from volcengine.maas import MaasService

            maas = MaasService(
                endpoint,
                values["region"],
                connection_timeout=values["connect_timeout"],
                socket_timeout=values["read_timeout"],
            )
            maas.set_ak(ak)
            values["volc_engine_maas_ak"] = ak
            values["volc_engine_maas_sk"] = sk
            maas.set_sk(sk)
            values["client"] = maas
        except ImportError:
            raise ImportError(
                "volcengine package not found, please install it with "
                "`pip install volcengine`"
            )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the VolcEngineMaas API."""
        normal_params = {
            "top_p": self.top_p,
            "temperature": self.temperature,
        }

        return {**normal_params, **self.model_kwargs}


class VolcEngineMaasLLM(LLM, VolcEngineMaasBase):
    """Volc Engine MaaS hosts a plethora of models, which you can use
    through this class.

    To use, you must have the ``volcengine`` Python package installed and
    set the access key and secret key either via environment variables or
    by passing them directly to this class. Both keys are required; see
    https://www.volcengine.com/docs/6291/65568 for help obtaining them.

    Example:
        .. code-block:: python

            from langchain.llms import VolcEngineMaasLLM

            model = VolcEngineMaasLLM(model="skylark-lite-public",
                                      volc_engine_maas_ak="your_ak",
                                      volc_engine_maas_sk="your_sk")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "volc-engine-maas-llm"

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        model_req = {
            "model": {
                "name": self.model,
            }
        }
        if self.model_version is not None:
            model_req["model"]["version"] = self.model_version

        return {
            **model_req,
            "messages": [{"role": "user", "content": prompt}],
            "parameters": {**self._default_params, **kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response = self.client.chat(params)

        return response.get("choice", {}).get("message", {}).get("content", "")

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        for res in self.client.stream_chat(params):
            if res:
                chunk = GenerationChunk(
                    text=res.get("choice", {}).get("message", {}).get("content", "")
                )
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
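Both `_call` here and `convert_dict_to_message` in the chat model unpack the same envelope, so a MaaS response is presumably shaped like the following; this is inferred from the accessors only, and real SDK responses may carry additional fields not visible in this diff:

```python
# Envelope implied by response.get("choice", {}).get("message", {}).get("content", "")
# (a sketch, not the full SDK schema):
response = {
    "choice": {
        "message": {
            "content": "model output text",
        },
    },
}
content = response.get("choice", {}).get("message", {}).get("content", "")
assert content == "model output text"
```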
@@ -0,0 +1,69 @@
"""Test volc engine maas chat model."""

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    HumanMessage,
    LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_default_call() -> None:
    """Test a valid chat call to volc engine."""
    chat = VolcEngineMaasChat()
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_multiple_history() -> None:
    """Test that multi-turn history works."""
    chat = VolcEngineMaasChat()

    response = chat(
        messages=[
            HumanMessage(content="Hello"),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you?"),
        ]
    )
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_stream() -> None:
    """Test that streaming works."""
    chat = VolcEngineMaasChat(streaming=True)
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    response = chat(
        messages=[
            HumanMessage(content="Hello"),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you?"),
        ],
        stream=True,
        callbacks=callback_manager,
    )
    assert callback_handler.llm_streams > 0
    assert isinstance(response.content, str)


def test_multiple_messages() -> None:
    """Test that generate with multiple message batches works."""
    chat = VolcEngineMaasChat()
    message = HumanMessage(content="Hi, how are you?")
    response = chat.generate([[message], [message]])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
@@ -0,0 +1,28 @@
"""Test volc engine maas LLM model."""

from typing import Generator

from langchain.llms.volcengine_maas import VolcEngineMaasLLM
from langchain.schema import LLMResult


def test_default_call() -> None:
    """Test a valid call to volc engine."""
    llm = VolcEngineMaasLLM()
    output = llm("tell me a joke")
    assert isinstance(output, str)


def test_generate() -> None:
    """Test a valid generate call to volc engine."""
    llm = VolcEngineMaasLLM()
    output = llm.generate(["tell me a joke"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)


def test_generate_stream() -> None:
    """Test a valid streaming call to volc engine."""
    llm = VolcEngineMaasLLM(streaming=True)
    output = llm.stream("tell me a joke")
    assert isinstance(output, Generator)
@@ -28,6 +28,7 @@ EXPECTED_ALL = [
     "ChatBaichuan",
     "ChatHunyuan",
     "GigaChat",
+    "VolcEngineMaasChat",
 ]

@@ -81,6 +81,7 @@ EXPECT_ALL = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "VolcEngineMaasLLM",
 ]