Add Minimax llm model to langchain (#7645)

- Description: Minimax is an AI startup from China; it recently released its
latest model and chat API, which are already widely used in China. This PR
adds the Minimax LLM to LangChain.
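
A quick usage sketch of the new wrapper (it mirrors the included notebook; substitute your own credentials):

```python
from langchain.llms import Minimax

minimax = Minimax(minimax_api_key="YOUR_API_KEY", minimax_group_id="YOUR_GROUP_ID")
minimax("What is the difference between panda and bear?")
```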
- Tag maintainer: @hwchase17, @baskaryan

---------

Co-authored-by: the <tao.he@hulu.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
HeTaoPKU 2023-07-28 13:53:23 +08:00 committed by GitHub
parent 0ad2d5f27a
commit d5884017a9
5 changed files with 368 additions and 0 deletions

View File

@@ -0,0 +1,176 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Minimax\n",
"\n",
"[Minimax](https://api.minimax.chat) is a Chinese startup that provides natural language processing models for companies and individuals.\n",
"\n",
"This example demonstrates using Langchain to interact with Minimax."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"\n",
"To run this notebook, you'll need a [Minimax account](https://api.minimax.chat), an [API key](https://api.minimax.chat/user-center/basic-information/interface-key), and a [Group ID](https://api.minimax.chat/user-center/basic-information)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Single model call"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Minimax"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# Load the model\n",
"minimax = Minimax(minimax_api_key=\"YOUR_API_KEY\", minimax_group_id=\"YOUR_GROUP_ID\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"# Prompt the model\n",
"minimax(\"What is the difference between panda and bear?\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Chained model calls"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# get api_key and group_id: https://api.minimax.chat/user-center/basic-information\n",
"# We need `MINIMAX_API_KEY` and `MINIMAX_GROUP_ID`\n",
"\n",
"import os\n",
"\n",
"os.environ[\"MINIMAX_API_KEY\"] = \"YOUR_API_KEY\"\n",
"os.environ[\"MINIMAX_GROUP_ID\"] = \"YOUR_GROUP_ID\""
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from langchain.llms import Minimax\n",
"from langchain import PromptTemplate, LLMChain"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"llm = Minimax()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"question = \"What NBA team won the Championship in the year Jay Zhou was born?\"\n",
"\n",
"llm_chain.run(question)"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,25 @@
# Minimax
>[Minimax](https://api.minimax.chat) is a Chinese startup that provides natural language processing models
> for companies and individuals.
## Installation and Setup
Get a [Minimax API key](https://api.minimax.chat/user-center/basic-information/interface-key) and set it as an environment variable (`MINIMAX_API_KEY`).
Get a [Minimax group ID](https://api.minimax.chat/user-center/basic-information) and set it as an environment variable (`MINIMAX_GROUP_ID`).
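For example, in Python (a minimal sketch; substitute your own credentials):
```python
import os

os.environ["MINIMAX_API_KEY"] = "YOUR_API_KEY"
os.environ["MINIMAX_GROUP_ID"] = "YOUR_GROUP_ID"
```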
## LLM
There exists a Minimax LLM wrapper, which you can access with
```python
from langchain.llms import Minimax
```
See a [usage example](/docs/modules/model_io/models/llms/integrations/minimax.html).
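A minimal call sketch (assumes the environment variables above are set):
```python
from langchain.llms import Minimax

llm = Minimax()
llm("What is the difference between panda and bear?")
```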
## Text Embedding Model
There exists a Minimax Embedding model, which you can access with
```python
from langchain.embeddings import MiniMaxEmbeddings
```
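A minimal embedding sketch (assumes the same environment variables; `embed_query` and `embed_documents` are the standard LangChain embedding methods):
```python
from langchain.embeddings import MiniMaxEmbeddings

embeddings = MiniMaxEmbeddings()
query_vector = embeddings.embed_query("What is the difference between panda and bear?")
doc_vectors = embeddings.embed_documents(["Pandas eat bamboo.", "Bears eat fish."])
```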

View File

@@ -33,6 +33,7 @@ from langchain.llms.human import HumanInputLLM
 from langchain.llms.koboldai import KoboldApiLLM
 from langchain.llms.llamacpp import LlamaCpp
 from langchain.llms.manifest import ManifestWrapper
+from langchain.llms.minimax import Minimax
 from langchain.llms.mlflow_ai_gateway import MlflowAIGateway
 from langchain.llms.modal import Modal
 from langchain.llms.mosaicml import MosaicML
@@ -92,6 +93,7 @@ __all__ = [
     "LlamaCpp",
     "TextGen",
     "ManifestWrapper",
+    "Minimax",
     "MlflowAIGateway",
     "Modal",
     "MosaicML",
@@ -152,6 +154,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "koboldai": KoboldApiLLM,
     "llamacpp": LlamaCpp,
     "textgen": TextGen,
+    "minimax": Minimax,
     "mlflow-ai-gateway": MlflowAIGateway,
     "modal": Modal,
     "mosaic": MosaicML,

View File

@@ -0,0 +1,155 @@
"""Wrapper around Minimax APIs."""
from __future__ import annotations
import logging
from typing import (
Any,
Dict,
List,
Optional,
)
import requests
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class _MinimaxEndpointClient(BaseModel):
"""An API client that talks to a Minimax llm endpoint."""
host: str
group_id: str
api_key: str
api_url: str
@root_validator(pre=True)
def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "api_url" not in values:
host = values["host"]
group_id = values["group_id"]
api_url = f"{host}/v1/text/chatcompletion?GroupId={group_id}"
values["api_url"] = api_url
return values
def post(self, request: Any) -> Any:
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post(self.api_url, headers=headers, json=request)
# TODO: error handling and automatic retries
if not response.ok:
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        # Parse the body once and surface Minimax-level errors from `base_resp`.
        response_json = response.json()
        if response_json["base_resp"]["status_code"] > 0:
            raise ValueError(
                f"API {response_json['base_resp']['status_code']}"
                f" error: {response_json['base_resp']['status_msg']}"
            )
        return response_json["reply"]
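# Illustrative request/response shapes for the chat completion endpoint, inferred
# from the calls above (an assumption, not official API documentation):
#   request:  {"model": "abab5.5-chat", "tokens_to_generate": 256,
#              "messages": [{"sender_type": "USER", "text": "..."}]}
#   response: {"base_resp": {"status_code": 0, "status_msg": "..."}, "reply": "..."}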
class Minimax(LLM):
"""Wrapper around Minimax large language models.
    To use, you should have the environment variables
    ``MINIMAX_API_KEY`` and ``MINIMAX_GROUP_ID`` set with your API key and
    group ID, or pass them as named parameters to the constructor.
Example:
.. code-block:: python
from langchain.llms.minimax import Minimax
minimax = Minimax(model="<model_name>", minimax_api_key="my-api-key",
minimax_group_id="my-group-id")
"""
_client: _MinimaxEndpointClient = PrivateAttr()
model: str = "abab5.5-chat"
"""Model name to use."""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
temperature: float = 0.7
"""A non-negative float that tunes the degree of randomness in generation."""
top_p: float = 0.95
"""Total probability mass of tokens to consider at each step."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
minimax_api_host: Optional[str] = None
minimax_group_id: Optional[str] = None
minimax_api_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["minimax_api_key"] = get_from_dict_or_env(
values, "minimax_api_key", "MINIMAX_API_KEY"
)
values["minimax_group_id"] = get_from_dict_or_env(
values, "minimax_group_id", "MINIMAX_GROUP_ID"
)
# Get custom api url from environment.
values["minimax_api_host"] = get_from_dict_or_env(
values,
"minimax_api_host",
"MINIMAX_API_HOST",
default="https://api.minimax.chat",
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model,
"tokens_to_generate": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
**self.model_kwargs,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "minimax"
def __init__(self, **data: Any):
super().__init__(**data)
self._client = _MinimaxEndpointClient(
host=self.minimax_api_host,
api_key=self.minimax_api_key,
group_id=self.minimax_group_id,
)
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Minimax's completion endpoint to chat
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = minimax("Tell me a joke.")
"""
request = self._default_params
request["messages"] = [{"sender_type": "USER", "text": prompt}]
request.update(kwargs)
response = self._client.post(request)
return response
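# Example usage (a minimal sketch; assumes MINIMAX_API_KEY and MINIMAX_GROUP_ID
# are set in the environment):
#
#     llm = Minimax()
#     llm("Tell me a joke.")
#
# Per-call overrides may be passed as keyword arguments; they are merged into
# the request payload via ``request.update(kwargs)`` in ``_call``.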

View File

@@ -0,0 +1,9 @@
"""Test Minimax API wrapper."""
from langchain.llms.minimax import Minimax
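# Note: this is an integration test against the live Minimax API; it requires
# MINIMAX_API_KEY and MINIMAX_GROUP_ID to be set in the environment.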
def test_minimax_call() -> None:
"""Test valid call to minimax."""
llm = Minimax(max_tokens=10)
output = llm("Hello world!")
assert isinstance(output, str)