Add Minimax chat model (#10776)
Resolves the merge conflicts for https://github.com/langchain-ai/langchain/pull/6757

---------

Co-authored-by: 何涛 <taohe@bytedance.com>
libs/langchain/langchain/chat_models/__init__.py

@@ -29,6 +29,7 @@ from langchain.chat_models.human import HumanInputChatModel
 from langchain.chat_models.jinachat import JinaChat
 from langchain.chat_models.konko import ChatKonko
 from langchain.chat_models.litellm import ChatLiteLLM
+from langchain.chat_models.minimax import MiniMaxChat
 from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
 from langchain.chat_models.ollama import ChatOllama
 from langchain.chat_models.openai import ChatOpenAI
@@ -48,6 +49,7 @@ __all__ = [
     "ChatVertexAI",
     "JinaChat",
     "HumanInputChatModel",
+    "MiniMaxChat",
     "ChatAnyscale",
     "ChatLiteLLM",
     "ErnieBotChat",
libs/langchain/langchain/chat_models/minimax.py (new file, 93 lines)

@@ -0,0 +1,93 @@
"""Wrapper around Minimax chat models."""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.minimax import MinimaxCommon
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatResult,
|
||||
HumanMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_message(msg_type: str, text: str) -> Dict:
|
||||
return {"sender_type": msg_type, "text": text}
|
||||
|
||||
|
||||
def _parse_chat_history(history: List[BaseMessage]) -> List:
|
||||
"""Parse a sequence of messages into history."""
|
||||
chat_history = []
|
||||
for message in history:
|
||||
if isinstance(message, HumanMessage):
|
||||
chat_history.append(_parse_message("USER", message.content))
|
||||
if isinstance(message, AIMessage):
|
||||
chat_history.append(_parse_message("BOT", message.content))
|
||||
return chat_history
|
||||
|
||||
|
||||
class MiniMaxChat(MinimaxCommon, BaseChatModel):
|
||||
"""Wrapper around Minimax large language models.
|
||||
|
||||
To use, you should have the environment variable ``MINIMAX_GROUP_ID`` and
|
||||
``MINIMAX_API_KEY`` set with your API token, or pass it as a named parameter to
|
||||
the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import MiniMaxChat
|
||||
llm = MiniMaxChat(model_name="abab5-chat")
|
||||
|
||||
"""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Generate next turn in the conversation.
|
||||
Args:
|
||||
messages: The history of the conversation as a list of messages. Code chat
|
||||
does not support context.
|
||||
stop: The list of stop words (optional).
|
||||
run_manager: The CallbackManager for LLM run, it's not used at the moment.
|
||||
|
||||
Returns:
|
||||
The ChatResult that contains outputs generated by the model.
|
||||
|
||||
Raises:
|
||||
ValueError: if the last message in the list is not from human.
|
||||
"""
|
||||
if not messages:
|
||||
raise ValueError(
|
||||
"You should provide at least one message to start the chat!"
|
||||
)
|
||||
history = _parse_chat_history(messages)
|
||||
payload = self._default_params
|
||||
payload["messages"] = history
|
||||
text = self._client.post(payload)
|
||||
|
||||
# This is required since the stop are not enforced by the model parameters
|
||||
return text if stop is None else enforce_stop_tokens(text, stop)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise NotImplementedError(
|
||||
"""Minimax AI doesn't support async requests at the moment."""
|
||||
)
|
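For reference, a minimal usage sketch of the new chat model (not part of this diff; it assumes `MINIMAX_GROUP_ID` and `MINIMAX_API_KEY` are set in the environment and that `abab5-chat` is an available model, as in the class docstring):

```python
from langchain.chat_models import MiniMaxChat
from langchain.schema import AIMessage, HumanMessage

chat = MiniMaxChat(model="abab5-chat")

# _parse_chat_history maps HumanMessage to sender_type "USER" and
# AIMessage to sender_type "BOT" before the payload is posted.
response = chat(
    [
        HumanMessage(content="Translate this sentence to French: I love programming."),
        AIMessage(content="J'aime programmer."),
        HumanMessage(content="Now translate it to German."),
    ]
)
print(response.content)  # the model's reply as an AIMessage
```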
libs/langchain/langchain/llms/minimax.py

@@ -15,7 +15,8 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import LLM
-from langchain.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr, root_validator
 from langchain.llms.utils import enforce_stop_tokens
+from langchain.pydantic_v1 import BaseModel, Field, root_validator
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
@@ -29,7 +30,7 @@ class _MinimaxEndpointClient(BaseModel):
     api_key: str
     api_url: str
 
-    @root_validator(pre=True)
+    @root_validator(pre=True, allow_reuse=True)
     def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         if "api_url" not in values:
             host = values["host"]
@@ -52,19 +53,8 @@ class _MinimaxEndpointClient(BaseModel):
         return response.json()["reply"]
 
 
-class Minimax(LLM):
-    """Wrapper around Minimax large language models.
-    To use, you should have the environment variable
-    ``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key,
-    or pass them as a named parameter to the constructor.
-    Example:
-        .. code-block:: python
-            from langchain.llms.minimax import Minimax
-            minimax = Minimax(model="<model_name>", minimax_api_key="my-api-key",
-                minimax_group_id="my-group-id")
-    """
-
-    _client: _MinimaxEndpointClient = PrivateAttr()
+class MinimaxCommon(BaseModel):
+    _client: _MinimaxEndpointClient
     model: str = "abab5.5-chat"
     """Model name to use."""
     max_tokens: int = 256
@@ -79,11 +69,6 @@ class Minimax(LLM):
     minimax_group_id: Optional[str] = None
     minimax_api_key: Optional[str] = None
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
@@ -131,6 +116,19 @@
             group_id=self.minimax_group_id,
         )
 
+
+class Minimax(MinimaxCommon, LLM):
+    """Wrapper around Minimax large language models.
+    To use, you should have the environment variables
+    ``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key,
+    or pass them as named parameters to the constructor.
+    Example:
+        .. code-block:: python
+            from langchain.llms.minimax import Minimax
+            minimax = Minimax(model="<model_name>", minimax_api_key="my-api-key",
+                minimax_group_id="my-group-id")
+    """
+
     def _call(
         self,
         prompt: str,
@@ -150,6 +148,10 @@
         request = self._default_params
         request["messages"] = [{"sender_type": "USER", "text": prompt}]
         request.update(kwargs)
-        response = self._client.post(request)
+        text = self._client.post(request)
+        if stop is not None:
+            # This is required since the stop tokens
+            # are not enforced by the model parameters
+            text = enforce_stop_tokens(text, stop)
 
-        return response
+        return text
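Since `MinimaxCommon` now owns the shared fields and client wiring, the completion-style `Minimax` wrapper keeps working unchanged. A minimal sketch under the same environment assumptions as above:

```python
from langchain.llms import Minimax

# Credentials can also be passed explicitly via minimax_api_key and
# minimax_group_id instead of the environment variables.
llm = Minimax(model="abab5.5-chat", max_tokens=10)

# Stop tokens are applied client-side via enforce_stop_tokens, since the
# Minimax API parameters do not enforce them.
print(llm("Hello world!", stop=["\n"]))
```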
libs/langchain/tests/integration_tests/llms/test_minimax.py

@@ -7,3 +7,16 @@ def test_minimax_call() -> None:
     llm = Minimax(max_tokens=10)
     output = llm("Hello world!")
     assert isinstance(output, str)
+
+
+def test_minimax_call_successful() -> None:
+    """Test valid call to minimax."""
+    llm = Minimax()
+    output = llm(
+        "A chain is a serial assembly of connected pieces, called links, \
+        typically made of metal, with an overall character similar to that\
+        of a rope in that it is flexible and curved in compression but \
+        linear, rigid, and load-bearing in tension. A chain may consist\
+        of two or more links."
+    )
+    assert isinstance(output, str)