mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 23:57:21 +00:00
langchain[minor]: add volcengine endpoint as LLM (#13942)
- **Description:** Volc Engine MaaS serves as an enterprise-grade, large-model service platform designed for developers. You can visit its homepage at https://www.volcengine.com/docs/82379/1099455 for details. This change will facilitate developers to integrate quickly with the platform. - **Issue:** None - **Dependencies:** volcengine - **Tag maintainer:** @baskaryan - **Twitter handle:** @he1v3tica --------- Co-authored-by: lvzhong <lvzhong@bytedance.com>
This commit is contained in:
parent
1600ebe6c7
commit
dbaeb163aa
177
docs/docs/integrations/chat/volcengine_maas.ipynb
Normal file
177
docs/docs/integrations/chat/volcengine_maas.ipynb
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "404758628c7b20f6",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Volc Engine Maas\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook provides you with a guide on how to get started with volc engine maas chat models."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2cd2ebd9d023c4d3",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Install the package\n",
|
||||||
|
"!pip install volcengine"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"id": "51e7f967cb78f5b7",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:43:37.131292Z",
|
||||||
|
"start_time": "2023-11-27T10:43:37.127250Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import VolcEngineMaasChat\n",
|
||||||
|
"from langchain.schema import HumanMessage"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"id": "139667d44689f9e0",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:43:49.911867Z",
|
||||||
|
"start_time": "2023-11-27T10:43:49.908329Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat = VolcEngineMaasChat(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e84ebc4feedcc739",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"or you can set access_key and secret_key in your environment variables\n",
|
||||||
|
"```bash\n",
|
||||||
|
"export VOLC_ACCESSKEY=YOUR_AK\n",
|
||||||
|
"export VOLC_SECRETKEY=YOUR_SK\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"id": "35da18414ad17aa0",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:43:53.101852Z",
|
||||||
|
"start_time": "2023-11-27T10:43:51.741041Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n为什么鸟儿不会玩电脑游戏?\\n\\n因为它们没有翅膀!')"
|
||||||
|
},
|
||||||
|
"execution_count": 26,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chat([HumanMessage(content=\"给我讲个笑话\")])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a55e5a9ed80ec49e",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# volc engine maas chat with stream"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"id": "b4e4049980ac68ef",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:43:55.120405Z",
|
||||||
|
"start_time": "2023-11-27T10:43:55.114707Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat = VolcEngineMaasChat(\n",
|
||||||
|
" volc_engine_maas_ak=\"your ak\",\n",
|
||||||
|
" volc_engine_maas_sk=\"your sk\",\n",
|
||||||
|
" streaming=True,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 28,
|
||||||
|
"id": "fe709a4ffb5c811d",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:43:58.775294Z",
|
||||||
|
"start_time": "2023-11-27T10:43:56.799401Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "AIMessage(content='好的,这是一个笑话:\\n\\n三岁的女儿说她会造句了,妈妈让她用“年轻”造句,女儿说:“妈妈减肥,一年轻了好几斤”。')"
|
||||||
|
},
|
||||||
|
"execution_count": 28,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chat([HumanMessage(content=\"给我讲个笑话\")])"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
124
docs/docs/integrations/llms/volcengine_maas.ipynb
Normal file
124
docs/docs/integrations/llms/volcengine_maas.ipynb
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "404758628c7b20f6",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Volc Engine Maas\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook provides you with a guide on how to get started with Volc Engine's MaaS llm models."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "946db204b33c2ef7",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Install the package\n",
|
||||||
|
"!pip install volcengine"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "51e7f967cb78f5b7",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:40:26.897649Z",
|
||||||
|
"start_time": "2023-11-27T10:40:26.552589Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.llms import VolcEngineMaasLLM\n",
|
||||||
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from langchain.schema.output_parser import StrOutputParser"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "139667d44689f9e0",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:40:27.938517Z",
|
||||||
|
"start_time": "2023-11-27T10:40:27.861324Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm = VolcEngineMaasLLM(volc_engine_maas_ak=\"your ak\", volc_engine_maas_sk=\"your sk\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e84ebc4feedcc739",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"or you can set access_key and secret_key in your environment variables\n",
|
||||||
|
"```bash\n",
|
||||||
|
"export VOLC_ACCESSKEY=YOUR_AK\n",
|
||||||
|
"export VOLC_SECRETKEY=YOUR_SK\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "35da18414ad17aa0",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-11-27T10:41:35.528526Z",
|
||||||
|
"start_time": "2023-11-27T10:41:32.562238Z"
|
||||||
|
},
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "'好的,下面是一个笑话:\\n\\n大学暑假我配了隐形眼镜,回家给爷爷说,我现在配了隐形眼镜。\\n爷爷让我给他看看,于是,我用小镊子夹了一片给爷爷看。\\n爷爷看完便准备出门,边走还边说:“真高级啊,还真是隐形眼镜!”\\n等爷爷出去后我才发现,我刚没夹起来!'"
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chain = PromptTemplate.from_template(\"给我讲个笑话\") | llm | StrOutputParser()\n",
|
||||||
|
"chain.invoke({})"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -43,6 +43,7 @@ from langchain.chat_models.openai import ChatOpenAI
|
|||||||
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
|
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
|
||||||
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
||||||
from langchain.chat_models.vertexai import ChatVertexAI
|
from langchain.chat_models.vertexai import ChatVertexAI
|
||||||
|
from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
|
||||||
from langchain.chat_models.yandex import ChatYandexGPT
|
from langchain.chat_models.yandex import ChatYandexGPT
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -73,4 +74,5 @@ __all__ = [
|
|||||||
"ChatBaichuan",
|
"ChatBaichuan",
|
||||||
"ChatHunyuan",
|
"ChatHunyuan",
|
||||||
"GigaChat",
|
"GigaChat",
|
||||||
|
"VolcEngineMaasChat",
|
||||||
]
|
]
|
||||||
|
141
libs/langchain/langchain/chat_models/volcengine_maas.py
Normal file
141
libs/langchain/langchain/chat_models/volcengine_maas.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
||||||
|
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
BaseMessage,
|
||||||
|
FunctionMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.llms.volcengine_maas import VolcEngineMaasBase
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
if isinstance(message, SystemMessage):
|
||||||
|
message_dict = {"role": "system", "content": message.content}
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
elif isinstance(message, FunctionMessage):
|
||||||
|
message_dict = {"role": "function", "content": message.content}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||||
|
content = _dict.get("choice", {}).get("message", {}).get("content", "")
|
||||||
|
return AIMessage(
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
||||||
|
|
||||||
|
"""volc engine maas hosts a plethora of models.
|
||||||
|
You can utilize these models through this class.
|
||||||
|
|
||||||
|
To use, you should have the ``volcengine`` python package installed.
|
||||||
|
and set access key and secret key by environment variable or direct pass those
|
||||||
|
to this class.
|
||||||
|
access key, secret key are required parameters which you could get help
|
||||||
|
https://www.volcengine.com/docs/6291/65568
|
||||||
|
|
||||||
|
In order to use them, it is necessary to install the 'volcengine' Python package.
|
||||||
|
The access key and secret key must be set either via environment variables or
|
||||||
|
passed directly to this class.
|
||||||
|
access key and secret key are mandatory parameters for which assistance can be
|
||||||
|
sought at https://www.volcengine.com/docs/6291/65568.
|
||||||
|
|
||||||
|
The two methods are as follows:
|
||||||
|
* Environment Variable
|
||||||
|
Set the environment variables 'VOLC_ACCESSKEY' and 'VOLC_SECRETKEY' with your
|
||||||
|
access key and secret key.
|
||||||
|
|
||||||
|
* Pass Directly to Class
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import VolcEngineMaasLLM
|
||||||
|
model = VolcEngineMaasChat(model="skylark-lite-public",
|
||||||
|
volc_engine_maas_ak="your_ak",
|
||||||
|
volc_engine_maas_sk="your_sk")
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
return "volc-engine-maas-chat"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""Return whether this model can be serialized by Langchain."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
**{"endpoint": self.endpoint, "model": self.model},
|
||||||
|
**super()._identifying_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _convert_prompt_msg_params(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
model_req = {
|
||||||
|
"model": {
|
||||||
|
"name": self.model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.model_version is not None:
|
||||||
|
model_req["model"]["version"] = self.model_version
|
||||||
|
return {
|
||||||
|
**model_req,
|
||||||
|
"messages": [_convert_message_to_dict(message) for message in messages],
|
||||||
|
"parameters": {**self._default_params, **kwargs},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
|
for res in self.client.stream_chat(params):
|
||||||
|
if res:
|
||||||
|
msg = convert_dict_to_message(res)
|
||||||
|
yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(msg.content)
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
completion = ""
|
||||||
|
if self.streaming:
|
||||||
|
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
||||||
|
completion += chunk.text
|
||||||
|
else:
|
||||||
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
|
res = self.client.chat(params)
|
||||||
|
msg = convert_dict_to_message(res)
|
||||||
|
completion = msg.content
|
||||||
|
|
||||||
|
message = AIMessage(content=completion)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
@ -504,6 +504,12 @@ def _import_yandex_gpt() -> Any:
|
|||||||
return YandexGPT
|
return YandexGPT
|
||||||
|
|
||||||
|
|
||||||
|
def _import_volcengine_maas() -> Any:
|
||||||
|
from langchain.llms.volcengine_maas import VolcEngineMaasLLM
|
||||||
|
|
||||||
|
return VolcEngineMaasLLM
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> Any:
|
def __getattr__(name: str) -> Any:
|
||||||
if name == "AI21":
|
if name == "AI21":
|
||||||
return _import_ai21()
|
return _import_ai21()
|
||||||
@ -665,6 +671,8 @@ def __getattr__(name: str) -> Any:
|
|||||||
return _import_xinference()
|
return _import_xinference()
|
||||||
elif name == "YandexGPT":
|
elif name == "YandexGPT":
|
||||||
return _import_yandex_gpt()
|
return _import_yandex_gpt()
|
||||||
|
elif name == "VolcEngineMaasLLM":
|
||||||
|
return _import_volcengine_maas()
|
||||||
elif name == "type_to_cls_dict":
|
elif name == "type_to_cls_dict":
|
||||||
# for backwards compatibility
|
# for backwards compatibility
|
||||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||||
@ -755,6 +763,7 @@ __all__ = [
|
|||||||
"JavelinAIGateway",
|
"JavelinAIGateway",
|
||||||
"QianfanLLMEndpoint",
|
"QianfanLLMEndpoint",
|
||||||
"YandexGPT",
|
"YandexGPT",
|
||||||
|
"VolcEngineMaasLLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -834,4 +843,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
|||||||
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
||||||
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
||||||
"yandex_gpt": _import_yandex_gpt,
|
"yandex_gpt": _import_yandex_gpt,
|
||||||
|
"VolcEngineMaasLLM": _import_volcengine_maas(),
|
||||||
}
|
}
|
||||||
|
176
libs/langchain/langchain/llms/volcengine_maas.py
Normal file
176
libs/langchain/langchain/llms/volcengine_maas.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, Iterator, List, Optional
|
||||||
|
|
||||||
|
from langchain_core.outputs import GenerationChunk
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
|
class VolcEngineMaasBase(BaseModel):
|
||||||
|
"""Base class for VolcEngineMaas models."""
|
||||||
|
|
||||||
|
client: Any
|
||||||
|
|
||||||
|
volc_engine_maas_ak: Optional[str] = None
|
||||||
|
"""access key for volc engine"""
|
||||||
|
volc_engine_maas_sk: Optional[str] = None
|
||||||
|
"""secret key for volc engine"""
|
||||||
|
|
||||||
|
endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com"
|
||||||
|
"""Endpoint of the VolcEngineMaas LLM."""
|
||||||
|
|
||||||
|
region: Optional[str] = "Region"
|
||||||
|
"""Region of the VolcEngineMaas LLM."""
|
||||||
|
|
||||||
|
model: str = "skylark-lite-public"
|
||||||
|
"""Model name. you could check this model details here
|
||||||
|
https://www.volcengine.com/docs/82379/1133187
|
||||||
|
and you could choose other models by change this field"""
|
||||||
|
model_version: Optional[str] = None
|
||||||
|
"""Model version. Only used in moonshot large language model.
|
||||||
|
you could check details here https://www.volcengine.com/docs/82379/1158281"""
|
||||||
|
|
||||||
|
top_p: Optional[float] = 0.8
|
||||||
|
"""Total probability mass of tokens to consider at each step."""
|
||||||
|
|
||||||
|
temperature: Optional[float] = 0.95
|
||||||
|
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||||
|
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""model special arguments, you could check detail on model page"""
|
||||||
|
|
||||||
|
streaming: bool = False
|
||||||
|
"""Whether to stream the results."""
|
||||||
|
|
||||||
|
connect_timeout: Optional[int] = 60
|
||||||
|
"""Timeout for connect to volc engine maas endpoint. Default is 60 seconds."""
|
||||||
|
|
||||||
|
read_timeout: Optional[int] = 60
|
||||||
|
"""Timeout for read response from volc engine maas endpoint.
|
||||||
|
Default is 60 seconds."""
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
ak = get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY")
|
||||||
|
sk = get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY")
|
||||||
|
endpoint = values["endpoint"]
|
||||||
|
if values["endpoint"] is not None and values["endpoint"] != "":
|
||||||
|
endpoint = values["endpoint"]
|
||||||
|
try:
|
||||||
|
from volcengine.maas import MaasService
|
||||||
|
|
||||||
|
maas = MaasService(
|
||||||
|
endpoint,
|
||||||
|
values["region"],
|
||||||
|
connection_timeout=values["connect_timeout"],
|
||||||
|
socket_timeout=values["read_timeout"],
|
||||||
|
)
|
||||||
|
maas.set_ak(ak)
|
||||||
|
values["volc_engine_maas_ak"] = ak
|
||||||
|
values["volc_engine_maas_sk"] = sk
|
||||||
|
maas.set_sk(sk)
|
||||||
|
values["client"] = maas
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"volcengine package not found, please install it with "
|
||||||
|
"`pip install volcengine`"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling VolcEngineMaas API."""
|
||||||
|
normal_params = {
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {**normal_params, **self.model_kwargs}
|
||||||
|
|
||||||
|
|
||||||
|
class VolcEngineMaasLLM(LLM, VolcEngineMaasBase):
|
||||||
|
"""volc engine maas hosts a plethora of models.
|
||||||
|
You can utilize these models through this class.
|
||||||
|
|
||||||
|
To use, you should have the ``volcengine`` python package installed.
|
||||||
|
and set access key and secret key by environment variable or direct pass those to
|
||||||
|
this class.
|
||||||
|
access key, secret key are required parameters which you could get help
|
||||||
|
https://www.volcengine.com/docs/6291/65568
|
||||||
|
|
||||||
|
In order to use them, it is necessary to install the 'volcengine' Python package.
|
||||||
|
The access key and secret key must be set either via environment variables or
|
||||||
|
passed directly to this class.
|
||||||
|
access key and secret key are mandatory parameters for which assistance can be
|
||||||
|
sought at https://www.volcengine.com/docs/6291/65568.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import VolcEngineMaasLLM
|
||||||
|
model = VolcEngineMaasLLM(model="skylark-lite-public",
|
||||||
|
volc_engine_maas_ak="your_ak",
|
||||||
|
volc_engine_maas_sk="your_sk")
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "volc-engine-maas-llm"
|
||||||
|
|
||||||
|
def _convert_prompt_msg_params(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict:
|
||||||
|
model_req = {
|
||||||
|
"model": {
|
||||||
|
"name": self.model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.model_version is not None:
|
||||||
|
model_req["model"]["version"] = self.model_version
|
||||||
|
|
||||||
|
return {
|
||||||
|
**model_req,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"parameters": {**self._default_params, **kwargs},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
if self.streaming:
|
||||||
|
completion = ""
|
||||||
|
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
|
||||||
|
completion += chunk.text
|
||||||
|
return completion
|
||||||
|
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
||||||
|
response = self.client.chat(params)
|
||||||
|
|
||||||
|
return response.get("choice", {}).get("message", {}).get("content", "")
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[GenerationChunk]:
|
||||||
|
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
||||||
|
for res in self.client.stream_chat(params):
|
||||||
|
if res:
|
||||||
|
chunk = GenerationChunk(
|
||||||
|
text=res.get("choice", {}).get("message", {}).get("content", "")
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
@ -0,0 +1,69 @@
|
|||||||
|
"""Test volc engine maas chat model."""
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
|
||||||
|
from langchain.schema import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ChatGeneration,
|
||||||
|
HumanMessage,
|
||||||
|
LLMResult,
|
||||||
|
)
|
||||||
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_call() -> None:
|
||||||
|
"""Test valid chat call to volc engine."""
|
||||||
|
chat = VolcEngineMaasChat()
|
||||||
|
response = chat(messages=[HumanMessage(content="Hello")])
|
||||||
|
assert isinstance(response, BaseMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_history() -> None:
|
||||||
|
"""Tests multiple history works."""
|
||||||
|
chat = VolcEngineMaasChat()
|
||||||
|
|
||||||
|
response = chat(
|
||||||
|
messages=[
|
||||||
|
HumanMessage(content="Hello"),
|
||||||
|
AIMessage(content="Hello!"),
|
||||||
|
HumanMessage(content="How are you?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert isinstance(response, BaseMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream() -> None:
|
||||||
|
"""Test that stream works."""
|
||||||
|
chat = VolcEngineMaasChat(streaming=True)
|
||||||
|
callback_handler = FakeCallbackHandler()
|
||||||
|
callback_manager = CallbackManager([callback_handler])
|
||||||
|
response = chat(
|
||||||
|
messages=[
|
||||||
|
HumanMessage(content="Hello"),
|
||||||
|
AIMessage(content="Hello!"),
|
||||||
|
HumanMessage(content="How are you?"),
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
callbacks=callback_manager,
|
||||||
|
)
|
||||||
|
assert callback_handler.llm_streams > 0
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_messages() -> None:
|
||||||
|
"""Tests multiple messages works."""
|
||||||
|
chat = VolcEngineMaasChat()
|
||||||
|
message = HumanMessage(content="Hi, how are you?")
|
||||||
|
response = chat.generate([[message], [message]])
|
||||||
|
|
||||||
|
assert isinstance(response, LLMResult)
|
||||||
|
assert len(response.generations) == 2
|
||||||
|
for generations in response.generations:
|
||||||
|
assert len(generations) == 1
|
||||||
|
for generation in generations:
|
||||||
|
assert isinstance(generation, ChatGeneration)
|
||||||
|
assert isinstance(generation.text, str)
|
||||||
|
assert generation.text == generation.message.content
|
@ -0,0 +1,28 @@
|
|||||||
|
"""Test volc engine maas LLM model."""
|
||||||
|
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from langchain.llms.volcengine_maas import VolcEngineMaasLLM
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_call() -> None:
|
||||||
|
"""Test valid call to volc engine."""
|
||||||
|
llm = VolcEngineMaasLLM()
|
||||||
|
output = llm("tell me a joke")
|
||||||
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate() -> None:
|
||||||
|
"""Test valid call to volc engine."""
|
||||||
|
llm = VolcEngineMaasLLM()
|
||||||
|
output = llm.generate(["tell me a joke"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert isinstance(output.generations, list)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream() -> None:
|
||||||
|
"""Test valid call to volc engine."""
|
||||||
|
llm = VolcEngineMaasLLM(streaming=True)
|
||||||
|
output = llm.stream("tell me a joke")
|
||||||
|
assert isinstance(output, Generator)
|
@ -28,6 +28,7 @@ EXPECTED_ALL = [
|
|||||||
"ChatBaichuan",
|
"ChatBaichuan",
|
||||||
"ChatHunyuan",
|
"ChatHunyuan",
|
||||||
"GigaChat",
|
"GigaChat",
|
||||||
|
"VolcEngineMaasChat",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,6 +81,7 @@ EXPECT_ALL = [
|
|||||||
"JavelinAIGateway",
|
"JavelinAIGateway",
|
||||||
"QianfanLLMEndpoint",
|
"QianfanLLMEndpoint",
|
||||||
"YandexGPT",
|
"YandexGPT",
|
||||||
|
"VolcEngineMaasLLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user