community[minor]: Add support for MLX models (chat & llm) (#18152)

**Description:** This PR adds support for MLX models, both chat (i.e.,
instruct) and llm (i.e., pretrained) types.
**Dependencies:** mlx, mlx_lm, transformers
**Twitter handle:** @Prince_Canuma
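
A minimal usage sketch of both types (model ID and kwargs taken from the notebooks in this PR; assumes an Apple-silicon machine with `mlx-lm`, `transformers`, and `huggingface_hub` installed):

```python
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.messages import HumanMessage

# llm (pretrained) usage: plain text completion
llm = MLXPipeline.from_model_id(
    "mlx-community/quantized-gemma-2b-it",
    pipeline_kwargs={"max_tokens": 10, "temp": 0.1},
)
print(llm.invoke("What is MLX?"))

# chat (instruct) usage: wraps the pipeline and applies the tokenizer's chat template
chat = ChatMLX(llm=llm)
print(chat.invoke([HumanMessage(content="What is MLX?")]).content)
```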

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Prince Canuma 2024-04-09 16:17:07 +02:00 committed by GitHub
parent 6baeaf4802
commit 1f9f4d8742
11 changed files with 848 additions and 0 deletions


@@ -0,0 +1,217 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MLX\n",
"\n",
"This notebook shows how to get started using `MLX` LLM's as chat models.\n",
"\n",
"In particular, we will:\n",
"1. Utilize the [MLXPipeline](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/mlx_pipelines.py), \n",
"2. Utilize the `ChatMLX` class to enable any of these LLMs to interface with LangChain's [Chat Messages](https://python.langchain.com/docs/modules/model_io/chat/#messages) abstraction.\n",
"3. Demonstrate how to use an open-source LLM to power an `ChatAgent` pipeline\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet mlx-lm transformers huggingface_hub"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Instantiate an LLM\n",
"\n",
"There are three LLM options to choose from."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
"\n",
"llm = MLXPipeline.from_model_id(\n",
" \"mlx-community/quantized-gemma-2b-it\",\n",
" pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Instantiate the `ChatMLX` to apply chat templates"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instantiate the chat model and some messages to pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema import (\n",
" HumanMessage,\n",
")\n",
"from langchain_community.chat_models.mlx import ChatMLX\n",
"\n",
"messages = [\n",
" HumanMessage(\n",
" content=\"What happens when an unstoppable force meets an immovable object?\"\n",
" ),\n",
"]\n",
"\n",
"chat_model = ChatMLX(llm=llm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspect how the chat messages are formatted for the LLM call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chat_model._to_chat_prompt(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = chat_model.invoke(messages)\n",
"print(res.content)"
]
},
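{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ChatMLX` also implements streaming, so tokens can be consumed as they are generated. A minimal sketch using the standard `stream` interface:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for chunk in chat_model.stream(messages):\n",
"    print(chunk.content, end=\"\", flush=True)"
]
},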
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Take it for a spin as an agent!\n",
"\n",
"Here we'll test out `gemma-2b-it` as a zero-shot `ReAct` Agent. The example below is taken from [here](https://python.langchain.com/docs/modules/agents/agent_types/react#using-chat-models).\n",
"\n",
"> Note: To run this section, you'll need to have a [SerpAPI Token](https://serpapi.com/) saved as an environment variable: `SERPAPI_API_KEY`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain import hub\n",
"from langchain.agents import AgentExecutor, load_tools\n",
"from langchain.agents.format_scratchpad import format_log_to_str\n",
"from langchain.agents.output_parsers import (\n",
" ReActJsonSingleInputOutputParser,\n",
")\n",
"from langchain.tools.render import render_text_description\n",
"from langchain_community.utilities import SerpAPIWrapper"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Configure the agent with a `react-json` style prompt and access to a search engine and calculator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# setup tools\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
"\n",
"# setup ReAct style prompt\n",
"prompt = hub.pull(\"hwchase17/react-json\")\n",
"prompt = prompt.partial(\n",
" tools=render_text_description(tools),\n",
" tool_names=\", \".join([t.name for t in tools]),\n",
")\n",
"\n",
"# define the agent\n",
"chat_model_with_stop = chat_model.bind(stop=[\"\\nObservation\"])\n",
"agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_log_to_str(x[\"intermediate_steps\"]),\n",
" }\n",
" | prompt\n",
" | chat_model_with_stop\n",
" | ReActJsonSingleInputOutputParser()\n",
")\n",
"\n",
"# instantiate AgentExecutor\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent_executor.invoke(\n",
" {\n",
" \"input\": \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
" }\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -0,0 +1,142 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "959300d4",
"metadata": {},
"source": [
"# MLX Local Pipelines\n",
"\n",
"MLX models can be run locally through the `MLXPipeline` class.\n",
"\n",
"The [MLX Community](https://huggingface.co/mlx-community) hosts over 150 models, all open source and publicly available on Hugging Face Model Hub a online platform where people can easily collaborate and build ML together.\n",
"\n",
"These can be called from LangChain either through this local pipeline wrapper or by calling their hosted inference endpoints through the MlXPipeline class. For more information on mlx, see the [examples repo](https://github.com/ml-explore/mlx-examples/tree/main/llms) notebook."
]
},
{
"cell_type": "markdown",
"id": "4c1b8450-5eaf-4d34-8341-2d785448a1ff",
"metadata": {
"tags": []
},
"source": [
"To use, you should have the ``mlx-lm`` python [package installed](https://pypi.org/project/mlx-lm/), as well as [transformers](https://pypi.org/project/transformers/). You can also install `huggingface_hub`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d772b637-de00-4663-bd77-9bc96d798db2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%pip install --upgrade --quiet mlx-lm transformers huggingface_hub"
]
},
{
"cell_type": "markdown",
"id": "91ad075f-71d5-4bc8-ab91-cc0ad5ef16bb",
"metadata": {},
"source": [
"### Model Loading\n",
"\n",
"Models can be loaded by specifying the model parameters using the `from_model_id` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "165ae236-962a-4763-8052-c4836d78a5d2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
"\n",
"pipe = MLXPipeline.from_model_id(\n",
" \"mlx-community/quantized-gemma-2b-it\",\n",
" pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "00104b27-0c15-4a97-b198-4512337ee211",
"metadata": {},
"source": [
"They can also be loaded by passing in an existing `transformers` pipeline directly"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f426a4f",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline\n",
"from mlx_lm import load\n",
"\n",
"model, tokenizer = load(\"mlx-community/quantized-gemma-2b-it\")\n",
"pipe = MLXPipeline(model=model, tokenizer=tokenizer)"
]
},
{
"cell_type": "markdown",
"id": "60e7ba8d",
"metadata": {},
"source": [
"### Create Chain\n",
"\n",
"With the model loaded into memory, you can compose it with a prompt to\n",
"form a chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3acf0069",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = PromptTemplate.from_template(template)\n",
"\n",
"chain = prompt | pipe\n",
"\n",
"question = \"What is electroencephalography?\"\n",
"\n",
"print(chain.invoke({\"question\": question}))"
]
}
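,
{
"cell_type": "markdown",
"id": "mlx-streaming-md",
"metadata": {},
"source": [
"`MLXPipeline` also implements streaming, so generated text can be consumed token by token. A minimal sketch using the standard `stream` interface:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mlx-streaming-code",
"metadata": {},
"outputs": [],
"source": [
"for chunk in pipe.stream(\"What is electroencephalography?\"):\n",
"    print(chunk, end=\"\", flush=True)"
]
}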
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -41,6 +41,7 @@ _module_lookup = {
"ChatLiteLLM": "langchain_community.chat_models.litellm",
"ChatLiteLLMRouter": "langchain_community.chat_models.litellm_router",
"ChatMLflowAIGateway": "langchain_community.chat_models.mlflow_ai_gateway",
"ChatMLX": "langchain_community.chat_models.mlx",
"ChatMaritalk": "langchain_community.chat_models.maritalk",
"ChatMlflow": "langchain_community.chat_models.mlflow",
"ChatOllama": "langchain_community.chat_models.ollama",


@@ -0,0 +1,196 @@
"""MLX Chat Wrapper."""
from typing import Any, Iterator, List, Optional
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
)
from langchain_community.llms.mlx_pipeline import MLXPipeline
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
class ChatMLX(BaseChatModel):
"""
Wrapper for using MLX LLMs as ChatModels.
Works with `MLXPipeline` LLM.
To use, you should have the ``mlx-lm`` python package installed.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatMLX
from langchain_community.llms import MLXPipeline
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b-it",
)
chat = ChatMLX(llm=llm)
"""
llm: MLXPipeline
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tokenizer = self.llm.tokenizer
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
def _to_chat_prompt(
self,
messages: List[BaseMessage],
tokenize: bool = False,
return_tensors: Optional[str] = None,
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")
if not isinstance(messages[-1], HumanMessage):
raise ValueError("Last message must be a HumanMessage!")
messages_dicts = [self._to_chatml_format(m) for m in messages]
return self.tokenizer.apply_chat_template(
messages_dicts,
tokenize=tokenize,
add_generation_prompt=True,
return_tensors=return_tensors,
)
def _to_chatml_format(self, message: BaseMessage) -> dict:
"""Convert LangChain message to ChatML format."""
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(f"Unknown message type: {type(message)}")
return {"role": role, "content": message.content}
@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []
for g in llm_result.generations[0]:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
)
chat_generations.append(chat_generation)
return ChatResult(
generations=chat_generations, llm_output=llm_result.llm_output
)
@property
def _llm_type(self) -> str:
return "mlx-chat-wrapper"
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
try:
import mlx.core as mx
from mlx_lm.utils import generate_step
except ImportError:
raise ValueError(
"Could not import mlx_lm python package. "
"Please install it with `pip install mlx_lm`."
)
model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs) or {}
temp: float = model_kwargs.get("temp", 0.0)
max_new_tokens: int = model_kwargs.get("max_tokens", 100)
repetition_penalty: Optional[float] = model_kwargs.get(
"repetition_penalty", None
)
repetition_context_size: Optional[int] = model_kwargs.get(
"repetition_context_size", None
)
llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")
prompt_tokens = mx.array(llm_input[0])
eos_token_id = self.tokenizer.eos_token_id
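# generate_step yields (token, prob) pairs indefinitely; zipping with range caps output at max_new_tokens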
for (token, prob), n in zip(
generate_step(
prompt_tokens,
self.llm.model,
temp,
repetition_penalty,
repetition_context_size,
),
range(max_new_tokens),
):
# decode the generated token to text
text = self.tokenizer.decode(token.item())
# yield text, if any
if text:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
yield chunk
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
# break if stop sequence found
if token == eos_token_id or (stop is not None and text in stop):
break


@@ -356,6 +356,12 @@ def _import_mlflow_ai_gateway() -> Type[BaseLLM]:
return MlflowAIGateway
def _import_mlx_pipeline() -> Type[BaseLLM]:
from langchain_community.llms.mlx_pipeline import MLXPipeline
return MLXPipeline
def _import_modal() -> Type[BaseLLM]:
from langchain_community.llms.modal import Modal
@@ -737,6 +743,8 @@ def __getattr__(name: str) -> Any:
return _import_mlflow()
elif name == "MlflowAIGateway":
return _import_mlflow_ai_gateway()
elif name == "MLXPipeline":
return _import_mlx_pipeline()
elif name == "Modal":
return _import_modal()
elif name == "MosaicML":
@@ -887,6 +895,7 @@ __all__ = [
"Minimax",
"Mlflow",
"MlflowAIGateway",
"MLXPipeline",
"Modal",
"MosaicML",
"NIBittensorLLM",
@@ -985,6 +994,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"mlflow": _import_mlflow,
"mlflow-chat": _import_mlflow_chat, # deprecated / only for back compat
"mlflow-ai-gateway": _import_mlflow_ai_gateway,
"mlx_pipeline": _import_mlx_pipeline,
"modal": _import_modal,
"mosaic": _import_mosaicml,
"nebula": _import_symblai_nebula,


@@ -0,0 +1,199 @@
from __future__ import annotations
import logging
from typing import Any, Iterator, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra
DEFAULT_MODEL_ID = "mlx-community/quantized-gemma-2b"
logger = logging.getLogger(__name__)
class MLXPipeline(LLM):
"""MLX Pipeline API.
To use, you should have the ``mlx-lm`` python package installed.
Example using from_model_id:
.. code-block:: python
from langchain_community.llms import MLXPipeline
pipe = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b",
pipeline_kwargs={"max_tokens": 10},
)
Example passing model and tokenizer in directly:
.. code-block:: python
from langchain_community.llms import MLXPipeline
from mlx_lm import load
model_id="mlx-community/quantized-gemma-2b"
model, tokenizer = load(model_id)
pipe = MLXPipeline(model=model, tokenizer=tokenizer)
"""
model_id: str = DEFAULT_MODEL_ID
"""Model name to use."""
model: Any #: :meta private:
"""Model."""
tokenizer: Any #: :meta private:
"""Tokenizer."""
tokenizer_config: Optional[dict] = None
"""
Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
"""
adapter_file: Optional[str] = None
"""
Path to the adapter file. If provided, applies LoRA layers to the model.
Defaults to None.
"""
lazy: bool = False
"""
If False, eval the model parameters to make sure they are
loaded in memory before returning; otherwise they will be loaded
when needed. Default: ``False``
"""
pipeline_kwargs: Optional[dict] = None
"""Keyword arguments passed to the pipeline."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@classmethod
def from_model_id(
cls,
model_id: str,
tokenizer_config: Optional[dict] = None,
adapter_file: Optional[str] = None,
lazy: bool = False,
pipeline_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> MLXPipeline:
"""Construct the pipeline object from model_id and task."""
try:
from mlx_lm import load
except ImportError:
raise ValueError(
"Could not import mlx_lm python package. "
"Please install it with `pip install mlx_lm`."
)
tokenizer_config = tokenizer_config or {}
if adapter_file:
model, tokenizer = load(model_id, tokenizer_config, adapter_file, lazy)
else:
model, tokenizer = load(model_id, tokenizer_config, lazy=lazy)
_pipeline_kwargs = pipeline_kwargs or {}
return cls(
model_id=model_id,
model=model,
tokenizer=tokenizer,
tokenizer_config=tokenizer_config,
adapter_file=adapter_file,
lazy=lazy,
pipeline_kwargs=_pipeline_kwargs,
**kwargs,
)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"model_id": self.model_id,
"tokenizer_config": self.tokenizer_config,
"adapter_file": self.adapter_file,
"lazy": self.lazy,
"pipeline_kwargs": self.pipeline_kwargs,
}
@property
def _llm_type(self) -> str:
return "mlx_pipeline"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
try:
from mlx_lm import generate
except ImportError:
raise ValueError(
"Could not import mlx_lm python package. "
"Please install it with `pip install mlx_lm`."
)
pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
return generate(self.model, self.tokenizer, prompt=prompt, **pipeline_kwargs)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
try:
import mlx.core as mx
from mlx_lm.utils import generate_step
except ImportError:
raise ValueError(
"Could not import mlx_lm python package. "
"Please install it with `pip install mlx_lm`."
)
pipeline_kwargs = kwargs.get("pipeline_kwargs", self.pipeline_kwargs)
temp: float = pipeline_kwargs.get("temp", 0.0)
max_new_tokens: int = pipeline_kwargs.get("max_tokens", 100)
repetition_penalty: Optional[float] = pipeline_kwargs.get(
"repetition_penalty", None
)
repetition_context_size: Optional[int] = pipeline_kwargs.get(
"repetition_context_size", None
)
prompt = self.tokenizer.encode(prompt, return_tensors="np")
prompt_tokens = mx.array(prompt[0])
eos_token_id = self.tokenizer.eos_token_id
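# generate_step yields (token, prob) pairs indefinitely; zipping with range caps output at max_new_tokens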
for (token, prob), n in zip(
generate_step(
prompt_tokens,
self.model,
temp,
repetition_penalty,
repetition_context_size,
),
range(max_new_tokens),
):
# decode the generated token to text
text = self.tokenizer.decode(token.item())
# yield text, if any
if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)
# break if stop sequence found
if token == eos_token_id or (stop is not None and text in stop):
break


@@ -0,0 +1,37 @@
"""Test MLX Chat Model."""
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
def test_default_call() -> None:
"""Test default model call."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b-it",
pipeline_kwargs={"max_new_tokens": 10},
)
chat = ChatMLX(llm=llm)
response = chat.invoke(input=[HumanMessage(content="Hello")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_history() -> None:
"""Tests multiple history works."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b-it",
pipeline_kwargs={"max_new_tokens": 10},
)
chat = ChatMLX(llm=llm)
response = chat.invoke(
input=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="How are you doing?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)


@@ -0,0 +1,33 @@
"""Test MLX Pipeline wrapper."""
from langchain_community.llms.mlx_pipeline import MLXPipeline
def test_mlx_pipeline_text_generation() -> None:
"""Test valid call to MLX text generation model."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b",
pipeline_kwargs={"max_tokens": 10},
)
output = llm.invoke("Say foo:")
assert isinstance(output, str)
def test_init_with_model_and_tokenizer() -> None:
"""Test initialization with a HF pipeline."""
from mlx_lm import load
model, tokenizer = load("mlx-community/quantized-gemma-2b")
llm = MLXPipeline(model=model, tokenizer=tokenizer)
output = llm.invoke("Say foo:")
assert isinstance(output, str)
def test_mlx_pipeline_runtime_kwargs() -> None:
"""Test passing pipeline_kwargs at invocation time."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b",
)
prompt = "Say foo:"
output = llm.invoke(prompt, pipeline_kwargs={"max_tokens": 2})
assert len(output) < 10


@@ -16,6 +16,7 @@ EXPECTED_ALL = [
"ChatMaritalk",
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatMLX",
"ChatOllama",
"ChatVertexAI",
"JinaChat",


@@ -0,0 +1,11 @@
"""Test MLX Chat wrapper."""
from importlib import import_module
def test_import_class() -> None:
"""Test that the class can be imported."""
module_name = "langchain_community.chat_models.mlx"
class_name = "ChatMLX"
module = import_module(module_name)
assert hasattr(module, class_name)


@@ -52,6 +52,7 @@ EXPECT_ALL = [
"Minimax",
"Mlflow",
"MlflowAIGateway",
"MLXPipeline",
"Modal",
"MosaicML",
"Nebula",