community[minor]: Add support for MLX models (chat & llm) (#18152)
**Description:** This PR adds support for MLX models, both chat (i.e., instruct) and llm (i.e., pretrained) types. **Dependencies:** mlx, mlx_lm, transformers **Twitter handle:** @Prince_Canuma --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in: parent 6baeaf4802, commit 1f9f4d8742
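Quick usage sketch (not part of the diff): the two classes added here are meant to be used together, roughly as in the notebooks below. The model id and sampling kwargs are just the notebook defaults, not requirements; assumes Apple Silicon with `mlx-lm` and `transformers` installed.

# Minimal end-to-end sketch based on the notebooks added in this PR.
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

# Load a quantized community model from the Hugging Face Hub.
llm = MLXPipeline.from_model_id(
    "mlx-community/quantized-gemma-2b-it",
    pipeline_kwargs={"max_tokens": 10, "temp": 0.1},
)

# Wrap the pipeline so chat messages are rendered through the tokenizer's chat template.
chat_model = ChatMLX(llm=llm)
response = chat_model.invoke([HumanMessage(content="What is MLX?")])
print(response.content)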
217
docs/docs/integrations/chat/mlx.ipynb
Normal file
@@ -0,0 +1,217 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLX\n",
    "\n",
    "This notebook shows how to get started using `MLX` LLMs as chat models.\n",
    "\n",
    "In particular, we will:\n",
    "1. Utilize the [MLXPipeline](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/mlx_pipelines.py), \n",
    "2. Utilize the `ChatMLX` class to enable any of these LLMs to interface with LangChain's [Chat Messages](https://python.langchain.com/docs/modules/model_io/chat/#messages) abstraction.\n",
    "3. Demonstrate how to use an open-source LLM to power a `ChatAgent` pipeline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet mlx-lm transformers huggingface_hub"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Instantiate an LLM\n",
    "\n",
    "There are three LLM options to choose from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
    "\n",
    "llm = MLXPipeline.from_model_id(\n",
    "    \"mlx-community/quantized-gemma-2b-it\",\n",
    "    pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Instantiate the `ChatMLX` to apply chat templates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the chat model and some messages to pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema import (\n",
    "    HumanMessage,\n",
    ")\n",
    "from langchain_community.chat_models.mlx import ChatMLX\n",
    "\n",
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"What happens when an unstoppable force meets an immovable object?\"\n",
    "    ),\n",
    "]\n",
    "\n",
    "chat_model = ChatMLX(llm=llm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect how the chat messages are formatted for the LLM call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_model._to_chat_prompt(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = chat_model.invoke(messages)\n",
    "print(res.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Take it for a spin as an agent!\n",
    "\n",
    "Here we'll test out `gemma-2b-it` as a zero-shot `ReAct` Agent. The example below is taken from [here](https://python.langchain.com/docs/modules/agents/agent_types/react#using-chat-models).\n",
    "\n",
    "> Note: To run this section, you'll need to have a [SerpAPI Token](https://serpapi.com/) saved as an environment variable: `SERPAPI_API_KEY`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain import hub\n",
    "from langchain.agents import AgentExecutor, load_tools\n",
    "from langchain.agents.format_scratchpad import format_log_to_str\n",
    "from langchain.agents.output_parsers import (\n",
    "    ReActJsonSingleInputOutputParser,\n",
    ")\n",
    "from langchain.tools.render import render_text_description\n",
    "from langchain_community.utilities import SerpAPIWrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configure the agent with a `react-json` style prompt and access to a search engine and calculator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup tools\n",
    "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
    "\n",
    "# setup ReAct style prompt\n",
    "prompt = hub.pull(\"hwchase17/react-json\")\n",
    "prompt = prompt.partial(\n",
    "    tools=render_text_description(tools),\n",
    "    tool_names=\", \".join([t.name for t in tools]),\n",
    ")\n",
    "\n",
    "# define the agent\n",
    "chat_model_with_stop = chat_model.bind(stop=[\"\\nObservation\"])\n",
    "agent = (\n",
    "    {\n",
    "        \"input\": lambda x: x[\"input\"],\n",
    "        \"agent_scratchpad\": lambda x: format_log_to_str(x[\"intermediate_steps\"]),\n",
    "    }\n",
    "    | prompt\n",
    "    | chat_model_with_stop\n",
    "    | ReActJsonSingleInputOutputParser()\n",
    ")\n",
    "\n",
    "# instantiate AgentExecutor\n",
    "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_executor.invoke(\n",
    "    {\n",
    "        \"input\": \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
    "    }\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
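Aside, not part of the diff: the `chat_model._to_chat_prompt(messages)` inspection cell above boils down to mapping each message to a ChatML-style dict and running the tokenizer's chat template, roughly as sketched here. `llm` is the MLXPipeline created in the notebook, and the template output shown in the trailing comment is only an illustration for a Gemma-style tokenizer.

# Rough sketch of what ChatMLX._to_chat_prompt does (see chat_models/mlx.py below).
from langchain_core.messages import HumanMessage

messages = [
    HumanMessage(
        content="What happens when an unstoppable force meets an immovable object?"
    )
]
# Each LangChain message becomes a ChatML-style dict ...
messages_dicts = [{"role": "user", "content": m.content} for m in messages]
# ... and the tokenizer's chat template renders the list into one prompt string.
prompt = llm.tokenizer.apply_chat_template(
    messages_dicts,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # for a Gemma-style template, roughly "<start_of_turn>user ... <start_of_turn>model"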
142
docs/docs/integrations/llms/mlx_pipelines.ipynb
Normal file
@@ -0,0 +1,142 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "959300d4",
   "metadata": {},
   "source": [
    "# MLX Local Pipelines\n",
    "\n",
    "MLX models can be run locally through the `MLXPipeline` class.\n",
    "\n",
    "The [MLX Community](https://huggingface.co/mlx-community) hosts over 150 models, all open source and publicly available on the Hugging Face Model Hub, an online platform where people can easily collaborate and build ML together.\n",
    "\n",
    "These can be called from LangChain through this local pipeline wrapper, the `MLXPipeline` class. For more information on MLX, see the [examples repo](https://github.com/ml-explore/mlx-examples/tree/main/llms)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1b8450-5eaf-4d34-8341-2d785448a1ff",
   "metadata": {
    "tags": []
   },
   "source": [
    "To use, you should have the ``mlx-lm`` python [package installed](https://pypi.org/project/mlx-lm/), as well as [transformers](https://pypi.org/project/transformers/). You can also install `huggingface_hub`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d772b637-de00-4663-bd77-9bc96d798db2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet mlx-lm transformers huggingface_hub"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91ad075f-71d5-4bc8-ab91-cc0ad5ef16bb",
   "metadata": {},
   "source": [
    "### Model Loading\n",
    "\n",
    "Models can be loaded by specifying the model parameters using the `from_model_id` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165ae236-962a-4763-8052-c4836d78a5d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
    "\n",
    "pipe = MLXPipeline.from_model_id(\n",
    "    \"mlx-community/quantized-gemma-2b-it\",\n",
    "    pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00104b27-0c15-4a97-b198-4512337ee211",
   "metadata": {},
   "source": [
    "They can also be loaded by passing in a model and tokenizer loaded with `mlx_lm` directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f426a4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
    "from mlx_lm import load\n",
    "\n",
    "model, tokenizer = load(\"mlx-community/quantized-gemma-2b-it\")\n",
    "pipe = MLXPipeline(model=model, tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60e7ba8d",
   "metadata": {},
   "source": [
    "### Create Chain\n",
    "\n",
    "With the model loaded into memory, you can compose it with a prompt to\n",
    "form a chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3acf0069",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts import PromptTemplate\n",
    "\n",
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "prompt = PromptTemplate.from_template(template)\n",
    "\n",
    "chain = prompt | pipe\n",
    "\n",
    "question = \"What is electroencephalography?\"\n",
    "\n",
    "print(chain.invoke({\"question\": question}))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
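Aside, not part of the diff: because `MLXPipeline` also implements `_stream` (see `mlx_pipeline.py` below), the standard Runnable streaming API should work on top of the `pipe` object created in this notebook; a minimal sketch:

# Streaming sketch: .stream() drives the _stream method and yields text chunks
# as they are generated. `pipe` is the MLXPipeline instance created above;
# the prompt is arbitrary.
for chunk in pipe.stream("What is electroencephalography?"):
    print(chunk, end="", flush=True)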
@@ -41,6 +41,7 @@ _module_lookup = {
    "ChatLiteLLM": "langchain_community.chat_models.litellm",
    "ChatLiteLLMRouter": "langchain_community.chat_models.litellm_router",
    "ChatMLflowAIGateway": "langchain_community.chat_models.mlflow_ai_gateway",
    "ChatMLX": "langchain_community.chat_models.mlx",
    "ChatMaritalk": "langchain_community.chat_models.maritalk",
    "ChatMlflow": "langchain_community.chat_models.mlflow",
    "ChatOllama": "langchain_community.chat_models.ollama",
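For context, entries in `_module_lookup` feed the package's lazy-import machinery, so `from langchain_community.chat_models import ChatMLX` only imports the MLX module on first access. A simplified sketch of that pattern, not the exact `__init__.py` code:

# Simplified sketch of a _module_lookup-backed lazy import; details may differ
# from the real langchain_community implementation.
import importlib
from typing import Any

_module_lookup = {
    "ChatMLX": "langchain_community.chat_models.mlx",
}


def __getattr__(name: str) -> Any:
    # Import the backing module only when the attribute is first requested.
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")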
196
libs/community/langchain_community/chat_models/mlx.py
Normal file
@@ -0,0 +1,196 @@
"""MLX Chat Wrapper."""

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
    LLMResult,
)

from langchain_community.llms.mlx_pipeline import MLXPipeline

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""


class ChatMLX(BaseChatModel):
    """
    Wrapper for using MLX LLMs as ChatModels.

    Works with `MLXPipeline` LLM.

    To use, you should have the ``mlx-lm`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMLX
            from langchain_community.llms import MLXPipeline

            llm = MLXPipeline.from_model_id(
                model_id="mlx-community/quantized-gemma-2b-it",
            )
            chat = ChatMLX(llm=llm)

    """

    llm: MLXPipeline
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    tokenizer: Any = None

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.tokenizer = self.llm.tokenizer

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
        tokenize: bool = False,
        return_tensors: Optional[str] = None,
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM."""
        if not messages:
            raise ValueError("At least one HumanMessage must be provided!")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("Last message must be a HumanMessage!")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        return self.tokenizer.apply_chat_template(
            messages_dicts,
            tokenize=tokenize,
            add_generation_prompt=True,
            return_tensors=return_tensors,
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format."""

        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        chat_generations = []

        for g in llm_result.generations[0]:
            chat_generation = ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            chat_generations.append(chat_generation)

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    @property
    def _llm_type(self) -> str:
        return "mlx-chat-wrapper"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        try:
            import mlx.core as mx
            from mlx_lm.utils import generate_step

        except ImportError:
            raise ValueError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )
        model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
        temp: float = model_kwargs.get("temp", 0.0)
        max_new_tokens: int = model_kwargs.get("max_tokens", 100)
        repetition_penalty: Optional[float] = model_kwargs.get(
            "repetition_penalty", None
        )
        repetition_context_size: Optional[int] = model_kwargs.get(
            "repetition_context_size", None
        )

        llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")

        prompt_tokens = mx.array(llm_input[0])

        eos_token_id = self.tokenizer.eos_token_id

        for (token, prob), n in zip(
            generate_step(
                prompt_tokens,
                self.llm.model,
                temp,
                repetition_penalty,
                repetition_context_size,
            ),
            range(max_new_tokens),
        ):
            # identify text to yield
            text: Optional[str] = None
            text = self.tokenizer.decode(token.item())

            # yield text, if any
            if text:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(text, chunk=chunk)

            # break if stop sequence found
            if token == eos_token_id or (stop is not None and text in stop):
                break
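Aside, not part of the diff: since `ChatMLX` implements `_stream`, token-by-token streaming should be available through the standard chat model API; a minimal sketch, reusing an `MLXPipeline` named `llm` as in the class docstring example:

# Minimal streaming sketch; BaseChatModel.stream() drives the _stream method
# defined in this file and yields AIMessageChunk objects.
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.mlx import ChatMLX

chat = ChatMLX(llm=llm)
for chunk in chat.stream([HumanMessage(content="Tell me a short joke.")]):
    print(chunk.content, end="", flush=True)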
@@ -356,6 +356,12 @@ def _import_mlflow_ai_gateway() -> Type[BaseLLM]:
    return MlflowAIGateway


def _import_mlx_pipeline() -> Type[BaseLLM]:
    from langchain_community.llms.mlx_pipeline import MLXPipeline

    return MLXPipeline


def _import_modal() -> Type[BaseLLM]:
    from langchain_community.llms.modal import Modal

@@ -737,6 +743,8 @@ def __getattr__(name: str) -> Any:
        return _import_mlflow()
    elif name == "MlflowAIGateway":
        return _import_mlflow_ai_gateway()
    elif name == "MLXPipeline":
        return _import_mlx_pipeline()
    elif name == "Modal":
        return _import_modal()
    elif name == "MosaicML":
@@ -887,6 +895,7 @@ __all__ = [
    "Minimax",
    "Mlflow",
    "MlflowAIGateway",
    "MLXPipeline",
    "Modal",
    "MosaicML",
    "NIBittensorLLM",
@@ -985,6 +994,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
        "mlflow": _import_mlflow,
        "mlflow-chat": _import_mlflow_chat,  # deprecated / only for back compat
        "mlflow-ai-gateway": _import_mlflow_ai_gateway,
        "mlx_pipeline": _import_mlx_pipeline,
        "modal": _import_modal,
        "mosaic": _import_mosaicml,
        "nebula": _import_symblai_nebula,
199
libs/community/langchain_community/llms/mlx_pipeline.py
Normal file
@@ -0,0 +1,199 @@
from __future__ import annotations

import logging
from typing import Any, Iterator, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra

DEFAULT_MODEL_ID = "mlx-community/quantized-gemma-2b"

logger = logging.getLogger(__name__)


class MLXPipeline(LLM):
    """MLX Pipeline API.

    To use, you should have the ``mlx-lm`` python package installed.

    Example using from_model_id:
        .. code-block:: python

            from langchain_community.llms import MLXPipeline
            pipe = MLXPipeline.from_model_id(
                model_id="mlx-community/quantized-gemma-2b",
                pipeline_kwargs={"max_tokens": 10},
            )
    Example passing model and tokenizer in directly:
        .. code-block:: python

            from langchain_community.llms import MLXPipeline
            from mlx_lm import load
            model_id="mlx-community/quantized-gemma-2b"
            model, tokenizer = load(model_id)
            pipe = MLXPipeline(model=model, tokenizer=tokenizer)
    """

    model_id: str = DEFAULT_MODEL_ID
    """Model name to use."""
    model: Any  #: :meta private:
    """Model."""
    tokenizer: Any  #: :meta private:
    """Tokenizer."""
    tokenizer_config: Optional[dict] = None
    """
    Configuration parameters specifically for the tokenizer.
    Defaults to an empty dictionary.
    """
    adapter_file: Optional[str] = None
    """
    Path to the adapter file. If provided, applies LoRA layers to the model.
    Defaults to None.
    """
    lazy: bool = False
    """
    If False, eval the model parameters to make sure they are
    loaded in memory before returning; otherwise they will be loaded
    when needed. Default: ``False``
    """
    pipeline_kwargs: Optional[dict] = None
    """Keyword arguments passed to the pipeline."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        tokenizer_config: Optional[dict] = None,
        adapter_file: Optional[str] = None,
        lazy: bool = False,
        pipeline_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> MLXPipeline:
        """Construct the pipeline object from model_id."""
        try:
            from mlx_lm import load

        except ImportError:
            raise ValueError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )

        tokenizer_config = tokenizer_config or {}
        if adapter_file:
            model, tokenizer = load(model_id, tokenizer_config, adapter_file, lazy)
        else:
            model, tokenizer = load(model_id, tokenizer_config, lazy=lazy)

        _pipeline_kwargs = pipeline_kwargs or {}
        return cls(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            tokenizer_config=tokenizer_config,
            adapter_file=adapter_file,
            lazy=lazy,
            pipeline_kwargs=_pipeline_kwargs,
            **kwargs,
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "tokenizer_config": self.tokenizer_config,
            "adapter_file": self.adapter_file,
            "lazy": self.lazy,
            "pipeline_kwargs": self.pipeline_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        return "mlx_pipeline"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        try:
            from mlx_lm import generate

        except ImportError:
            raise ValueError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )

        pipeline_kwargs = kwargs.get("pipeline_kwargs", {})

        return generate(self.model, self.tokenizer, prompt=prompt, **pipeline_kwargs)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        try:
            import mlx.core as mx
            from mlx_lm.utils import generate_step

        except ImportError:
            raise ValueError(
                "Could not import mlx_lm python package. "
                "Please install it with `pip install mlx_lm`."
            )

        pipeline_kwargs = kwargs.get("pipeline_kwargs", self.pipeline_kwargs)

        temp: float = pipeline_kwargs.get("temp", 0.0)
        max_new_tokens: int = pipeline_kwargs.get("max_tokens", 100)
        repetition_penalty: Optional[float] = pipeline_kwargs.get(
            "repetition_penalty", None
        )
        repetition_context_size: Optional[int] = pipeline_kwargs.get(
            "repetition_context_size", None
        )

        prompt = self.tokenizer.encode(prompt, return_tensors="np")

        prompt_tokens = mx.array(prompt[0])

        eos_token_id = self.tokenizer.eos_token_id

        for (token, prob), n in zip(
            generate_step(
                prompt_tokens,
                self.model,
                temp,
                repetition_penalty,
                repetition_context_size,
            ),
            range(max_new_tokens),
        ):
            # identify text to yield
            text: Optional[str] = None
            text = self.tokenizer.decode(token.item())

            # yield text, if any
            if text:
                chunk = GenerationChunk(text=text)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

            # break if stop sequence found
            if token == eos_token_id or (stop is not None and text in stop):
                break
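Aside, not part of the diff: `_call` reads `pipeline_kwargs` out of the invoke-time kwargs, so generation settings can be overridden per request without rebuilding the pipeline, which is exactly what the runtime-kwargs integration test below exercises. A minimal sketch, assuming an `MLXPipeline` instance named `pipe`:

# Per-call generation parameters: the pipeline_kwargs passed to invoke() are
# forwarded through _call to mlx_lm.generate for this request only.
output = pipe.invoke(
    "Say foo:",
    pipeline_kwargs={"max_tokens": 2},
)
print(output)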
@@ -0,0 +1,37 @@
"""Test MLX Chat Model."""

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline


def test_default_call() -> None:
    """Test default model call."""
    llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b-it",
        pipeline_kwargs={"max_new_tokens": 10},
    )
    chat = ChatMLX(llm=llm)
    response = chat.invoke(input=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_multiple_history() -> None:
    """Test that multi-turn chat history works."""
    llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b-it",
        pipeline_kwargs={"max_new_tokens": 10},
    )
    chat = ChatMLX(llm=llm)

    response = chat.invoke(
        input=[
            HumanMessage(content="Hello."),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you doing?"),
        ]
    )
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
33
libs/community/tests/integration_tests/llms/test_mlx_pipeline.py
Executable file
@@ -0,0 +1,33 @@
"""Test MLX Pipeline wrapper."""

from langchain_community.llms.mlx_pipeline import MLXPipeline


def test_mlx_pipeline_text_generation() -> None:
    """Test valid call to MLX text generation model."""
    llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b",
        pipeline_kwargs={"max_tokens": 10},
    )
    output = llm.invoke("Say foo:")
    assert isinstance(output, str)


def test_init_with_model_and_tokenizer() -> None:
    """Test initialization with a model and tokenizer loaded via mlx_lm."""
    from mlx_lm import load

    model, tokenizer = load("mlx-community/quantized-gemma-2b")
    llm = MLXPipeline(model=model, tokenizer=tokenizer)
    output = llm.invoke("Say foo:")
    assert isinstance(output, str)


def test_mlx_pipeline_runtime_kwargs() -> None:
    """Test passing pipeline_kwargs at invoke time."""
    llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b",
    )
    prompt = "Say foo:"
    output = llm.invoke(prompt, pipeline_kwargs={"max_tokens": 2})
    assert len(output) < 10
@@ -16,6 +16,7 @@ EXPECTED_ALL = [
    "ChatMaritalk",
    "ChatMlflow",
    "ChatMLflowAIGateway",
    "ChatMLX",
    "ChatOllama",
    "ChatVertexAI",
    "JinaChat",
11
libs/community/tests/unit_tests/chat_models/test_mlx.py
Normal file
@@ -0,0 +1,11 @@
"""Test MLX Chat wrapper."""
from importlib import import_module


def test_import_class() -> None:
    """Test that the class can be imported."""
    module_name = "langchain_community.chat_models.mlx"
    class_name = "ChatMLX"

    module = import_module(module_name)
    assert hasattr(module, class_name)
@@ -52,6 +52,7 @@ EXPECT_ALL = [
    "Minimax",
    "Mlflow",
    "MlflowAIGateway",
    "MLXPipeline",
    "Modal",
    "MosaicML",
    "Nebula",