mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
Add Minimax chat model (#10776)
resolve the merging issues for https://github.com/langchain-ai/langchain/pull/6757 --------- Co-authored-by: 何涛 <taohe@bytedance.com>
This commit is contained in:
parent
c656a6b966
commit
f505320a73
70
docs/extras/integrations/chat/minimax.ipynb
Normal file
70
docs/extras/integrations/chat/minimax.ipynb
Normal file
@ -0,0 +1,70 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# MiniMax\n",
|
||||
"\n",
|
||||
"[Minimax](https://api.minimax.chat) is a Chinese startup that provides LLM service for companies and individuals.\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with MiniMax Inference for Chat."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"MINIMAX_GROUP_ID\"] = \"MINIMAX_GROUP_ID\"\n",
|
||||
"os.environ[\"MINIMAX_API_KEY\"] = \"MINIMAX_API_KEY\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import MiniMaxChat\n",
|
||||
"from langchain.schema import HumanMessage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat = MiniMaxChat()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat(\n",
|
||||
" [\n",
|
||||
" HumanMessage(\n",
|
||||
" content=\"Translate this sentence from English to French. I love programming.\"\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -94,7 +94,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import Minimax\n",
|
||||
"from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
|
@ -108,7 +108,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import Modal\n",
|
||||
"from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -17,6 +17,14 @@ See a [usage example](/docs/modules/model_io/models/llms/integrations/minimax.ht
|
||||
from langchain.llms import Minimax
|
||||
```
|
||||
|
||||
## Chat Models
|
||||
|
||||
See a [usage example](/docs/modules/model_io/models/chat/integrations/minimax.html)
|
||||
|
||||
```python
|
||||
from langchain.chat_models import MiniMaxChat
|
||||
```
|
||||
|
||||
## Text Embedding Model
|
||||
|
||||
There exists a Minimax Embedding model, which you can access with
|
||||
|
@ -29,6 +29,7 @@ from langchain.chat_models.human import HumanInputChatModel
|
||||
from langchain.chat_models.jinachat import JinaChat
|
||||
from langchain.chat_models.konko import ChatKonko
|
||||
from langchain.chat_models.litellm import ChatLiteLLM
|
||||
from langchain.chat_models.minimax import MiniMaxChat
|
||||
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
|
||||
from langchain.chat_models.ollama import ChatOllama
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
@ -48,6 +49,7 @@ __all__ = [
|
||||
"ChatVertexAI",
|
||||
"JinaChat",
|
||||
"HumanInputChatModel",
|
||||
"MiniMaxChat",
|
||||
"ChatAnyscale",
|
||||
"ChatLiteLLM",
|
||||
"ErnieBotChat",
|
||||
|
93
libs/langchain/langchain/chat_models/minimax.py
Normal file
93
libs/langchain/langchain/chat_models/minimax.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Wrapper around Minimax chat models."""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.minimax import MinimaxCommon
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatResult,
|
||||
HumanMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_message(msg_type: str, text: str) -> Dict:
|
||||
return {"sender_type": msg_type, "text": text}
|
||||
|
||||
|
||||
def _parse_chat_history(history: List[BaseMessage]) -> List:
|
||||
"""Parse a sequence of messages into history."""
|
||||
chat_history = []
|
||||
for message in history:
|
||||
if isinstance(message, HumanMessage):
|
||||
chat_history.append(_parse_message("USER", message.content))
|
||||
if isinstance(message, AIMessage):
|
||||
chat_history.append(_parse_message("BOT", message.content))
|
||||
return chat_history
|
||||
|
||||
|
||||
class MiniMaxChat(MinimaxCommon, BaseChatModel):
|
||||
"""Wrapper around Minimax large language models.
|
||||
|
||||
To use, you should have the environment variable ``MINIMAX_GROUP_ID`` and
|
||||
``MINIMAX_API_KEY`` set with your API token, or pass it as a named parameter to
|
||||
the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import MiniMaxChat
|
||||
llm = MiniMaxChat(model_name="abab5-chat")
|
||||
|
||||
"""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Generate next turn in the conversation.
|
||||
Args:
|
||||
messages: The history of the conversation as a list of messages. Code chat
|
||||
does not support context.
|
||||
stop: The list of stop words (optional).
|
||||
run_manager: The CallbackManager for LLM run, it's not used at the moment.
|
||||
|
||||
Returns:
|
||||
The ChatResult that contains outputs generated by the model.
|
||||
|
||||
Raises:
|
||||
ValueError: if the last message in the list is not from human.
|
||||
"""
|
||||
if not messages:
|
||||
raise ValueError(
|
||||
"You should provide at least one message to start the chat!"
|
||||
)
|
||||
history = _parse_chat_history(messages)
|
||||
payload = self._default_params
|
||||
payload["messages"] = history
|
||||
text = self._client.post(payload)
|
||||
|
||||
# This is required since the stop are not enforced by the model parameters
|
||||
return text if stop is None else enforce_stop_tokens(text, stop)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise NotImplementedError(
|
||||
"""Minimax AI doesn't support async requests at the moment."""
|
||||
)
|
@ -15,7 +15,8 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr, root_validator
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -29,7 +30,7 @@ class _MinimaxEndpointClient(BaseModel):
|
||||
api_key: str
|
||||
api_url: str
|
||||
|
||||
@root_validator(pre=True)
|
||||
@root_validator(pre=True, allow_reuse=True)
|
||||
def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "api_url" not in values:
|
||||
host = values["host"]
|
||||
@ -52,19 +53,8 @@ class _MinimaxEndpointClient(BaseModel):
|
||||
return response.json()["reply"]
|
||||
|
||||
|
||||
class Minimax(LLM):
|
||||
"""Wrapper around Minimax large language models.
|
||||
To use, you should have the environment variable
|
||||
``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key,
|
||||
or pass them as a named parameter to the constructor.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from langchain.llms.minimax import Minimax
|
||||
minimax = Minimax(model="<model_name>", minimax_api_key="my-api-key",
|
||||
minimax_group_id="my-group-id")
|
||||
"""
|
||||
|
||||
_client: _MinimaxEndpointClient = PrivateAttr()
|
||||
class MinimaxCommon(BaseModel):
|
||||
_client: _MinimaxEndpointClient
|
||||
model: str = "abab5.5-chat"
|
||||
"""Model name to use."""
|
||||
max_tokens: int = 256
|
||||
@ -79,11 +69,6 @@ class Minimax(LLM):
|
||||
minimax_group_id: Optional[str] = None
|
||||
minimax_api_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
@ -131,6 +116,19 @@ class Minimax(LLM):
|
||||
group_id=self.minimax_group_id,
|
||||
)
|
||||
|
||||
|
||||
class Minimax(MinimaxCommon, LLM):
|
||||
"""Wrapper around Minimax large language models.
|
||||
To use, you should have the environment variable
|
||||
``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key,
|
||||
or pass them as a named parameter to the constructor.
|
||||
Example:
|
||||
. code-block:: python
|
||||
from langchain.llms.minimax import Minimax
|
||||
minimax = Minimax(model="<model_name>", minimax_api_key="my-api-key",
|
||||
minimax_group_id="my-group-id")
|
||||
"""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -150,6 +148,10 @@ class Minimax(LLM):
|
||||
request = self._default_params
|
||||
request["messages"] = [{"sender_type": "USER", "text": prompt}]
|
||||
request.update(kwargs)
|
||||
response = self._client.post(request)
|
||||
text = self._client.post(request)
|
||||
if stop is not None:
|
||||
# This is required since the stop tokens
|
||||
# are not enforced by the model parameters
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return response
|
||||
return text
|
||||
|
@ -7,3 +7,16 @@ def test_minimax_call() -> None:
|
||||
llm = Minimax(max_tokens=10)
|
||||
output = llm("Hello world!")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_minimax_call_successful() -> None:
|
||||
"""Test valid call to minimax."""
|
||||
llm = Minimax()
|
||||
output = llm(
|
||||
"A chain is a serial assembly of connected pieces, called links, \
|
||||
typically made of metal, with an overall character similar to that\
|
||||
of a rope in that it is flexible and curved in compression but \
|
||||
linear, rigid, and load-bearing in tension. A chain may consist\
|
||||
of two or more links."
|
||||
)
|
||||
assert isinstance(output, str)
|
||||
|
Loading…
Reference in New Issue
Block a user