langchain[minor]: add volcengine endpoint as LLM (#13942)
- **Description:** Volc Engine MaaS is an enterprise-grade large-model service platform designed for developers. See its homepage at https://www.volcengine.com/docs/82379/1099455 for details. This change helps developers integrate quickly with the platform.
- **Issue:** None
- **Dependencies:** volcengine
- **Tag maintainer:** @baskaryan
- **Twitter handle:** @he1v3tica

Co-authored-by: lvzhong <lvzhong@bytedance.com>
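For orientation (not part of the diff), a minimal usage sketch of the two entry points this PR adds; it assumes the `volcengine` package is installed and that the placeholder credentials are replaced with real ones:

    # Minimal usage sketch; "your_ak"/"your_sk" are placeholders.
    import os

    os.environ["VOLC_ACCESSKEY"] = "your_ak"
    os.environ["VOLC_SECRETKEY"] = "your_sk"

    from langchain.llms import VolcEngineMaasLLM
    from langchain.chat_models import VolcEngineMaasChat
    from langchain.schema import HumanMessage

    llm = VolcEngineMaasLLM(model="skylark-lite-public")
    print(llm("tell me a joke"))

    chat = VolcEngineMaasChat(model="skylark-lite-public")
    print(chat([HumanMessage(content="Hello")]).content)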
libs/langchain/langchain/chat_models/__init__.py
@@ -43,6 +43,7 @@ from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 from langchain.chat_models.vertexai import ChatVertexAI
+from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
 from langchain.chat_models.yandex import ChatYandexGPT

 __all__ = [
@@ -73,4 +74,5 @@ __all__ = [
     "ChatBaichuan",
     "ChatHunyuan",
     "GigaChat",
+    "VolcEngineMaasChat",
 ]
libs/langchain/langchain/chat_models/volcengine_maas.py (new file, 141 lines)
@@ -0,0 +1,141 @@
from __future__ import annotations

from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.volcengine_maas import VolcEngineMaasBase


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Map a LangChain message onto the role/content dict the MaaS API expects."""
    if isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {"role": "function", "content": message.content}
    else:
        raise ValueError(f"Got unknown message type: {message}")
    return message_dict


def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
    """Convert a MaaS API response dict into an AIMessage."""
    content = _dict.get("choice", {}).get("message", {}).get("content", "")
    return AIMessage(
        content=content,
    )


class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
    """Chat model hosted on Volc Engine MaaS, which serves a plethora of models.

    To use, you should have the ``volcengine`` python package installed and
    set an access key and secret key, either via environment variables or by
    passing them directly to this class. Both keys are required; see
    https://www.volcengine.com/docs/6291/65568 for how to obtain them.

    * Environment variables: set ``VOLC_ACCESSKEY`` and ``VOLC_SECRETKEY``
      to your access key and secret key.

    * Pass directly to the class:

      Example:
          .. code-block:: python

              from langchain.chat_models import VolcEngineMaasChat
              model = VolcEngineMaasChat(model="skylark-lite-public",
                                         volc_engine_maas_ak="your_ak",
                                         volc_engine_maas_sk="your_sk")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "volc-engine-maas-chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"endpoint": self.endpoint, "model": self.model},
            **super()._identifying_params,
        }

    def _convert_prompt_msg_params(
        self,
        messages: List[BaseMessage],
        **kwargs: Any,
    ) -> Dict[str, Any]:
        model_req = {
            "model": {
                "name": self.model,
            }
        }
        if self.model_version is not None:
            model_req["model"]["version"] = self.model_version
        return {
            **model_req,
            "messages": [_convert_message_to_dict(message) for message in messages],
            "parameters": {**self._default_params, **kwargs},
        }

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._convert_prompt_msg_params(messages, **kwargs)
        for res in self.client.stream_chat(params):
            if res:
                msg = convert_dict_to_message(res)
                yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
                if run_manager:
                    run_manager.on_llm_new_token(msg.content)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""
        if self.streaming:
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            params = self._convert_prompt_msg_params(messages, **kwargs)
            res = self.client.chat(params)
            msg = convert_dict_to_message(res)
            completion = msg.content

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])
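For reference, `_convert_prompt_msg_params` above is what shapes the outgoing request. A hypothetical three-message conversation would serialize roughly as follows — a sketch inferred from the code, using the default `top_p`/`temperature` defined on `VolcEngineMaasBase`:

    # Approximate payload for messages=[HumanMessage("Hi"),
    # AIMessage("Hello!"), HumanMessage("How are you?")] with defaults:
    payload = {
        "model": {"name": "skylark-lite-public"},
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ],
        "parameters": {"top_p": 0.8, "temperature": 0.95},
    }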
libs/langchain/langchain/llms/__init__.py
@@ -504,6 +504,12 @@ def _import_yandex_gpt() -> Any:
     return YandexGPT


+def _import_volcengine_maas() -> Any:
+    from langchain.llms.volcengine_maas import VolcEngineMaasLLM
+
+    return VolcEngineMaasLLM
+
+
 def __getattr__(name: str) -> Any:
     if name == "AI21":
         return _import_ai21()
@@ -665,6 +671,8 @@ def __getattr__(name: str) -> Any:
         return _import_xinference()
     elif name == "YandexGPT":
         return _import_yandex_gpt()
+    elif name == "VolcEngineMaasLLM":
+        return _import_volcengine_maas()
     elif name == "type_to_cls_dict":
         # for backwards compatibility
         type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -755,6 +763,7 @@ __all__ = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "VolcEngineMaasLLM",
 ]
@@ -834,4 +843,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "javelin-ai-gateway": _import_javelin_ai_gateway,
         "qianfan_endpoint": _import_baidu_qianfan_endpoint,
         "yandex_gpt": _import_yandex_gpt,
+        "VolcEngineMaasLLM": _import_volcengine_maas,
     }
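A quick illustration (mine, not part of the diff) of what the registry hooks just modified buy: the module-level `__getattr__` keeps the import lazy, and the `get_type_to_cls_dict` entry must map to the import function itself — not its result — to match the `Callable[[], Type[BaseLLM]]` value type:

    # Resolved lazily: __getattr__ calls _import_volcengine_maas() on access.
    from langchain.llms import VolcEngineMaasLLM

    # The registry stores zero-arg callables; call one to obtain the class.
    from langchain.llms import get_type_to_cls_dict
    llm_cls = get_type_to_cls_dict()["VolcEngineMaasLLM"]()
    assert llm_cls is VolcEngineMaasLLM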
libs/langchain/langchain/llms/volcengine_maas.py (new file, 176 lines)
@@ -0,0 +1,176 @@
from __future__ import annotations

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env


class VolcEngineMaasBase(BaseModel):
    """Base class for VolcEngineMaas models."""

    client: Any

    volc_engine_maas_ak: Optional[str] = None
    """Access key for Volc Engine."""
    volc_engine_maas_sk: Optional[str] = None
    """Secret key for Volc Engine."""

    endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com"
    """Endpoint of the VolcEngineMaas LLM."""

    region: Optional[str] = "Region"
    """Region of the VolcEngineMaas LLM."""

    model: str = "skylark-lite-public"
    """Model name. You can check the details of a model at
    https://www.volcengine.com/docs/82379/1133187
    and choose a different model by changing this field."""
    model_version: Optional[str] = None
    """Model version. Only used by the Moonshot large language model;
    see https://www.volcengine.com/docs/82379/1158281 for details."""

    top_p: Optional[float] = 0.8
    """Total probability mass of tokens to consider at each step."""

    temperature: Optional[float] = 0.95
    """A non-negative float that tunes the degree of randomness in generation."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Model-specific arguments; check the model's page for details."""

    streaming: bool = False
    """Whether to stream the results."""

    connect_timeout: Optional[int] = 60
    """Timeout for connecting to the Volc Engine MaaS endpoint.
    Default is 60 seconds."""

    read_timeout: Optional[int] = 60
    """Timeout for reading the response from the Volc Engine MaaS endpoint.
    Default is 60 seconds."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials and build the MaasService client."""
        ak = get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY")
        sk = get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY")
        endpoint = values["endpoint"]
        try:
            from volcengine.maas import MaasService

            maas = MaasService(
                endpoint,
                values["region"],
                connection_timeout=values["connect_timeout"],
                socket_timeout=values["read_timeout"],
            )
            maas.set_ak(ak)
            maas.set_sk(sk)
            values["volc_engine_maas_ak"] = ak
            values["volc_engine_maas_sk"] = sk
            values["client"] = maas
        except ImportError:
            raise ImportError(
                "volcengine package not found, please install it with "
                "`pip install volcengine`"
            )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the VolcEngineMaas API."""
        normal_params = {
            "top_p": self.top_p,
            "temperature": self.temperature,
        }

        return {**normal_params, **self.model_kwargs}


class VolcEngineMaasLLM(LLM, VolcEngineMaasBase):
    """Completion-style LLM hosted on Volc Engine MaaS, which serves a
    plethora of models.

    To use, you should have the ``volcengine`` python package installed and
    set an access key and secret key, either via the ``VOLC_ACCESSKEY`` and
    ``VOLC_SECRETKEY`` environment variables or by passing them directly to
    this class. Both keys are required; see
    https://www.volcengine.com/docs/6291/65568 for how to obtain them.

    Example:
        .. code-block:: python

            from langchain.llms import VolcEngineMaasLLM
            model = VolcEngineMaasLLM(model="skylark-lite-public",
                                      volc_engine_maas_ak="your_ak",
                                      volc_engine_maas_sk="your_sk")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "volc-engine-maas-llm"

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        model_req = {
            "model": {
                "name": self.model,
            }
        }
        if self.model_version is not None:
            model_req["model"]["version"] = self.model_version

        return {
            **model_req,
            "messages": [{"role": "user", "content": prompt}],
            "parameters": {**self._default_params, **kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response = self.client.chat(params)

        return response.get("choice", {}).get("message", {}).get("content", "")

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        for res in self.client.stream_chat(params):
            if res:
                chunk = GenerationChunk(
                    text=res.get("choice", {}).get("message", {}).get("content", "")
                )
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
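The test files that follow exercise the live endpoint. As an aside, here is a hedged offline sketch (mine, not in this PR) that swaps in a fake client to exercise the request/response plumbing; it assumes `volcengine` is installed (construction runs `validate_environment`) and that the pydantic model permits attribute reassignment (the pydantic v1 default):

    # Offline sketch: FakeMaasClient is hypothetical, not part of the PR.
    from langchain.llms.volcengine_maas import VolcEngineMaasLLM

    class FakeMaasClient:
        def chat(self, params: dict) -> dict:
            # Echo the prompt back in the response shape _call expects.
            prompt = params["messages"][0]["content"]
            return {"choice": {"message": {"content": f"echo: {prompt}"}}}

    llm = VolcEngineMaasLLM(volc_engine_maas_ak="fake", volc_engine_maas_sk="fake")
    llm.client = FakeMaasClient()  # replace the real MaasService
    assert llm("hello") == "echo: hello"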
New test file for the chat model (69 lines):
@@ -0,0 +1,69 @@
"""Test volc engine maas chat model."""

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.volcengine_maas import VolcEngineMaasChat
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    HumanMessage,
    LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_default_call() -> None:
    """Test valid chat call to volc engine."""
    chat = VolcEngineMaasChat()
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_multiple_history() -> None:
    """Test that multi-turn history works."""
    chat = VolcEngineMaasChat()

    response = chat(
        messages=[
            HumanMessage(content="Hello"),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you?"),
        ]
    )
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_stream() -> None:
    """Test that streaming works."""
    chat = VolcEngineMaasChat(streaming=True)
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    response = chat(
        messages=[
            HumanMessage(content="Hello"),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you?"),
        ],
        stream=True,
        callbacks=callback_manager,
    )
    assert callback_handler.llm_streams > 0
    assert isinstance(response.content, str)


def test_multiple_messages() -> None:
    """Test that generating over multiple message batches works."""
    chat = VolcEngineMaasChat()
    message = HumanMessage(content="Hi, how are you?")
    response = chat.generate([[message], [message]])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
New test file for the LLM (28 lines):
@@ -0,0 +1,28 @@
"""Test volc engine maas LLM model."""

from typing import Generator

from langchain.llms.volcengine_maas import VolcEngineMaasLLM
from langchain.schema import LLMResult


def test_default_call() -> None:
    """Test valid call to volc engine."""
    llm = VolcEngineMaasLLM()
    output = llm("tell me a joke")
    assert isinstance(output, str)


def test_generate() -> None:
    """Test valid generate call to volc engine."""
    llm = VolcEngineMaasLLM()
    output = llm.generate(["tell me a joke"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)


def test_generate_stream() -> None:
    """Test valid streaming call to volc engine."""
    llm = VolcEngineMaasLLM(streaming=True)
    output = llm.stream("tell me a joke")
    assert isinstance(output, Generator)
@@ -28,6 +28,7 @@ EXPECTED_ALL = [
     "ChatBaichuan",
     "ChatHunyuan",
     "GigaChat",
+    "VolcEngineMaasChat",
 ]
@@ -81,6 +81,7 @@ EXPECT_ALL = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "VolcEngineMaasLLM",
 ]