diff --git a/docs/docs/integrations/chat/mlx.ipynb b/docs/docs/integrations/chat/mlx.ipynb new file mode 100644 index 00000000000..07a4cc638fc --- /dev/null +++ b/docs/docs/integrations/chat/mlx.ipynb @@ -0,0 +1,217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MLX\n", + "\n", + "This notebook shows how to get started using `MLX` LLMs as chat models.\n", + "\n", + "In particular, we will:\n", + "1. Utilize the [MLXPipeline](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/mlx_pipeline.py) to run a local MLX model,\n", + "2. Utilize the `ChatMLX` class to enable this LLM to interface with LangChain's [Chat Messages](https://python.langchain.com/docs/modules/model_io/chat/#messages) abstraction.\n", + "3. Demonstrate how to use an open-source LLM to power a `ChatAgent` pipeline.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet mlx-lm transformers huggingface_hub" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Instantiate an LLM\n", + "\n", + "Instantiate a local MLX model with the `MLXPipeline` wrapper." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.llms.mlx_pipeline import MLXPipeline\n", + "\n", + "llm = MLXPipeline.from_model_id(\n", + " \"mlx-community/quantized-gemma-2b-it\",\n", + " pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Instantiate `ChatMLX` to apply chat templates" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Instantiate the chat model and some messages to pass to it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema import (\n", + " HumanMessage,\n", + ")\n", + "from langchain_community.chat_models.mlx import ChatMLX\n", + "\n", + "messages = [\n", + " HumanMessage(\n", + " content=\"What happens when an unstoppable force meets an immovable object?\"\n", + " ),\n", + "]\n", + "\n", + "chat_model = ChatMLX(llm=llm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspect how the chat messages are formatted for the LLM call." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chat_model._to_chat_prompt(messages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Call the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = chat_model.invoke(messages)\n", + "print(res.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Take it for a spin as an agent!\n", + "\n", + "Here we'll test out `gemma-2b-it` as a zero-shot `ReAct` Agent. 
The example below is taken from [here](https://python.langchain.com/docs/modules/agents/agent_types/react#using-chat-models).\n", + "\n", + "> Note: To run this section, you'll need to have a [SerpAPI Token](https://serpapi.com/) saved as an environment variable: `SERPAPI_API_KEY`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import hub\n", + "from langchain.agents import AgentExecutor, load_tools\n", + "from langchain.agents.format_scratchpad import format_log_to_str\n", + "from langchain.agents.output_parsers import (\n", + " ReActJsonSingleInputOutputParser,\n", + ")\n", + "from langchain.tools.render import render_text_description\n", + "from langchain_community.utilities import SerpAPIWrapper" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configure the agent with a `react-json` style prompt and access to a search engine and calculator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# setup tools\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n", + "\n", + "# setup ReAct style prompt\n", + "prompt = hub.pull(\"hwchase17/react-json\")\n", + "prompt = prompt.partial(\n", + " tools=render_text_description(tools),\n", + " tool_names=\", \".join([t.name for t in tools]),\n", + ")\n", + "\n", + "# define the agent\n", + "chat_model_with_stop = chat_model.bind(stop=[\"\\nObservation\"])\n", + "agent = (\n", + " {\n", + " \"input\": lambda x: x[\"input\"],\n", + " \"agent_scratchpad\": lambda x: format_log_to_str(x[\"intermediate_steps\"]),\n", + " }\n", + " | prompt\n", + " | chat_model_with_stop\n", + " | ReActJsonSingleInputOutputParser()\n", + ")\n", + "\n", + "# instantiate AgentExecutor\n", + "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent_executor.invoke(\n", + " {\n", + " \"input\": \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n", + " }\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/integrations/llms/mlx_pipelines.ipynb b/docs/docs/integrations/llms/mlx_pipelines.ipynb new file mode 100644 index 00000000000..7ea22f1df30 --- /dev/null +++ b/docs/docs/integrations/llms/mlx_pipelines.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "959300d4", + "metadata": {}, + "source": [ + "# MLX Local Pipelines\n", + "\n", + "MLX models can be run locally through the `MLXPipeline` class.\n", + "\n", + "The [MLX Community](https://huggingface.co/mlx-community) hosts over 150 models, all open source and publicly available on the Hugging Face Model Hub, an online platform where people can easily collaborate and build ML together.\n", + "\n", + "These models can be called from LangChain through the local `MLXPipeline` wrapper. 
For more information on MLX, see the [mlx-examples repo](https://github.com/ml-explore/mlx-examples/tree/main/llms)." + ] + }, + { + "cell_type": "markdown", + "id": "4c1b8450-5eaf-4d34-8341-2d785448a1ff", + "metadata": { + "tags": [] + }, + "source": [ + "To use, you should have the ``mlx-lm`` python [package installed](https://pypi.org/project/mlx-lm/), as well as [transformers](https://pypi.org/project/transformers/). You can also install `huggingface_hub` to download models from the Hugging Face Hub." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d772b637-de00-4663-bd77-9bc96d798db2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet mlx-lm transformers huggingface_hub" + ] + }, + { + "cell_type": "markdown", + "id": "91ad075f-71d5-4bc8-ab91-cc0ad5ef16bb", + "metadata": {}, + "source": [ + "### Model Loading\n", + "\n", + "Models can be loaded by specifying the model parameters using the `from_model_id` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "165ae236-962a-4763-8052-c4836d78a5d2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain_community.llms.mlx_pipeline import MLXPipeline\n", + "\n", + "pipe = MLXPipeline.from_model_id(\n", + " \"mlx-community/quantized-gemma-2b-it\",\n", + " pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "00104b27-0c15-4a97-b198-4512337ee211", + "metadata": {}, + "source": [ + "They can also be instantiated by passing in a model and tokenizer loaded directly with `mlx_lm`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f426a4f", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.llms.mlx_pipeline import MLXPipeline\n", + "from mlx_lm import load\n", + "\n", + "model, tokenizer = load(\"mlx-community/quantized-gemma-2b-it\")\n", + "pipe = MLXPipeline(model=model, tokenizer=tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "60e7ba8d", + "metadata": {}, + "source": [ + "### Create Chain\n", + "\n", + "With the model loaded into memory, you can compose it with a prompt to\n", + "form a chain."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3acf0069", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import PromptTemplate\n", + "\n", + "template = \"\"\"Question: {question}\n", + "\n", + "Answer: Let's think step by step.\"\"\"\n", + "prompt = PromptTemplate.from_template(template)\n", + "\n", + "chain = prompt | pipe\n", + "\n", + "question = \"What is electroencephalography?\"\n", + "\n", + "print(chain.invoke({\"question\": question}))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 18f202813b5..acfaf20b164 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -41,6 +41,7 @@ _module_lookup = { "ChatLiteLLM": "langchain_community.chat_models.litellm", "ChatLiteLLMRouter": "langchain_community.chat_models.litellm_router", "ChatMLflowAIGateway": "langchain_community.chat_models.mlflow_ai_gateway", + "ChatMLX": "langchain_community.chat_models.mlx", "ChatMaritalk": "langchain_community.chat_models.maritalk", "ChatMlflow": "langchain_community.chat_models.mlflow", "ChatOllama": "langchain_community.chat_models.ollama", diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py new file mode 100644 index 00000000000..e6f2b70473d --- /dev/null +++ b/libs/community/langchain_community/chat_models/mlx.py @@ -0,0 +1,196 @@ +"""MLX Chat Wrapper.""" + +from typing import Any, Iterator, List, Optional + +from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ( + ChatGeneration, + ChatGenerationChunk, + ChatResult, + LLMResult, +) + +from langchain_community.llms.mlx_pipeline import MLXPipeline + +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.""" + + +class ChatMLX(BaseChatModel): + """ + Wrapper for using MLX LLM's as ChatModels. + + Works with `MLXPipeline` LLM. + + To use, you should have the ``mlx-lm`` python package installed. + + Example: + .. 
code-block:: python + + from langchain_community.chat_models import ChatMLX + from langchain_community.llms import MLXPipeline + + llm = MLXPipeline.from_model_id( + model_id="mlx-community/quantized-gemma-2b-it", + ) + chat = ChatMLX(llm=llm) + + """ + + llm: MLXPipeline + system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT) + tokenizer: Any = None + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.tokenizer = self.llm.tokenizer + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + llm_input = self._to_chat_prompt(messages) + llm_result = self.llm._generate( + prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs + ) + return self._to_chat_result(llm_result) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + llm_input = self._to_chat_prompt(messages) + llm_result = await self.llm._agenerate( + prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs + ) + return self._to_chat_result(llm_result) + + def _to_chat_prompt( + self, + messages: List[BaseMessage], + tokenize: bool = False, + return_tensors: Optional[str] = None, + ) -> str: + """Convert a list of messages into a prompt format expected by the wrapped LLM.""" + if not messages: + raise ValueError("At least one HumanMessage must be provided!") + + if not isinstance(messages[-1], HumanMessage): + raise ValueError("Last message must be a HumanMessage!") + + messages_dicts = [self._to_chatml_format(m) for m in messages] + + return self.tokenizer.apply_chat_template( + messages_dicts, + tokenize=tokenize, + add_generation_prompt=True, + return_tensors=return_tensors, + ) + + def _to_chatml_format(self, message: BaseMessage) -> dict: + """Convert LangChain message to ChatML format.""" + + if isinstance(message, SystemMessage): + role = "system" + elif isinstance(message, AIMessage): + role = "assistant" + elif isinstance(message, HumanMessage): + role = "user" + else: + raise ValueError(f"Unknown message type: {type(message)}") + + return {"role": role, "content": message.content} + + @staticmethod + def _to_chat_result(llm_result: LLMResult) -> ChatResult: + chat_generations = [] + + for g in llm_result.generations[0]: + chat_generation = ChatGeneration( + message=AIMessage(content=g.text), generation_info=g.generation_info + ) + chat_generations.append(chat_generation) + + return ChatResult( + generations=chat_generations, llm_output=llm_result.llm_output + ) + + @property + def _llm_type(self) -> str: + return "mlx-chat-wrapper" + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + try: + import mlx.core as mx + from mlx_lm.utils import generate_step + + except ImportError: + raise ValueError( + "Could not import mlx_lm python package. " + "Please install it with `pip install mlx_lm`."
+ ) + model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs) + temp: float = model_kwargs.get("temp", 0.0) + max_new_tokens: int = model_kwargs.get("max_tokens", 100) + repetition_penalty: Optional[float] = model_kwargs.get( + "repetition_penalty", None + ) + repetition_context_size: Optional[int] = model_kwargs.get( + "repetition_context_size", None + ) + + llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np") + + prompt_tokens = mx.array(llm_input[0]) + + eos_token_id = self.tokenizer.eos_token_id + + for (token, prob), n in zip( + generate_step( + prompt_tokens, + self.llm.model, + temp, + repetition_penalty, + repetition_context_size, + ), + range(max_new_tokens), + ): + # identify text to yield + text: Optional[str] = None + text = self.tokenizer.decode(token.item()) + + # yield text, if any + if text: + chunk = ChatGenerationChunk(message=AIMessageChunk(content=text)) + yield chunk + if run_manager: + run_manager.on_llm_new_token(text, chunk=chunk) + + # break if stop sequence found + if token == eos_token_id or (stop is not None and text in stop): + break diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py index f51ee89b727..dd77fd8ce85 100644 --- a/libs/community/langchain_community/llms/__init__.py +++ b/libs/community/langchain_community/llms/__init__.py @@ -356,6 +356,12 @@ def _import_mlflow_ai_gateway() -> Type[BaseLLM]: return MlflowAIGateway +def _import_mlx_pipeline() -> Type[BaseLLM]: + from langchain_community.llms.mlx_pipeline import MLXPipeline + + return MLXPipeline + + def _import_modal() -> Type[BaseLLM]: from langchain_community.llms.modal import Modal @@ -737,6 +743,8 @@ def __getattr__(name: str) -> Any: return _import_mlflow() elif name == "MlflowAIGateway": return _import_mlflow_ai_gateway() + elif name == "MLXPipeline": + return _import_mlx_pipeline() elif name == "Modal": return _import_modal() elif name == "MosaicML": @@ -887,6 +895,7 @@ __all__ = [ "Minimax", "Mlflow", "MlflowAIGateway", + "MLXPipeline", "Modal", "MosaicML", "NIBittensorLLM", @@ -985,6 +994,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "mlflow": _import_mlflow, "mlflow-chat": _import_mlflow_chat, # deprecated / only for back compat "mlflow-ai-gateway": _import_mlflow_ai_gateway, + "mlx_pipeline": _import_mlx_pipeline, "modal": _import_modal, "mosaic": _import_mosaicml, "nebula": _import_symblai_nebula, diff --git a/libs/community/langchain_community/llms/mlx_pipeline.py b/libs/community/langchain_community/llms/mlx_pipeline.py new file mode 100644 index 00000000000..8445fc955a9 --- /dev/null +++ b/libs/community/langchain_community/llms/mlx_pipeline.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import logging +from typing import Any, Iterator, List, Mapping, Optional + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk +from langchain_core.pydantic_v1 import Extra + +DEFAULT_MODEL_ID = "mlx-community/quantized-gemma-2b" + +logger = logging.getLogger(__name__) + + +class MLXPipeline(LLM): + """MLX Pipeline API. + + To use, you should have the ``mlx-lm`` python package installed. + + Example using from_model_id: + .. 
code-block:: python + + from langchain_community.llms import MLXPipeline + pipe = MLXPipeline.from_model_id( + model_id="mlx-community/quantized-gemma-2b", + pipeline_kwargs={"max_tokens": 10}, + ) + Example passing model and tokenizer in directly: + .. code-block:: python + + from langchain_community.llms import MLXPipeline + from mlx_lm import load + model_id="mlx-community/quantized-gemma-2b" + model, tokenizer = load(model_id) + pipe = MLXPipeline(model=model, tokenizer=tokenizer) + """ + + model_id: str = DEFAULT_MODEL_ID + """Model name to use.""" + model: Any #: :meta private: + """Model.""" + tokenizer: Any #: :meta private: + """Tokenizer.""" + tokenizer_config: Optional[dict] = None + """ + Configuration parameters specifically for the tokenizer. + Defaults to an empty dictionary. + """ + adapter_file: Optional[str] = None + """ + Path to the adapter file. If provided, applies LoRA layers to the model. + Defaults to None. + """ + lazy: bool = False + """ + If False, eval the model parameters to make sure they are + loaded in memory before returning, otherwise they will be loaded + when needed. Default: ``False`` + """ + pipeline_kwargs: Optional[dict] = None + """Keyword arguments passed to the pipeline.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @classmethod + def from_model_id( + cls, + model_id: str, + tokenizer_config: Optional[dict] = None, + adapter_file: Optional[str] = None, + lazy: bool = False, + pipeline_kwargs: Optional[dict] = None, + **kwargs: Any, + ) -> MLXPipeline: + """Construct the pipeline object from model_id.""" + try: + from mlx_lm import load + + except ImportError: + raise ValueError( + "Could not import mlx_lm python package. " + "Please install it with `pip install mlx_lm`." + ) + + tokenizer_config = tokenizer_config or {} + if adapter_file: + model, tokenizer = load(model_id, tokenizer_config, adapter_file, lazy) + else: + model, tokenizer = load(model_id, tokenizer_config, lazy=lazy) + + _pipeline_kwargs = pipeline_kwargs or {} + return cls( + model_id=model_id, + model=model, + tokenizer=tokenizer, + tokenizer_config=tokenizer_config, + adapter_file=adapter_file, + lazy=lazy, + pipeline_kwargs=_pipeline_kwargs, + **kwargs, + ) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return { + "model_id": self.model_id, + "tokenizer_config": self.tokenizer_config, + "adapter_file": self.adapter_file, + "lazy": self.lazy, + "pipeline_kwargs": self.pipeline_kwargs, + } + + @property + def _llm_type(self) -> str: + return "mlx_pipeline" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + try: + from mlx_lm import generate + + except ImportError: + raise ValueError( + "Could not import mlx_lm python package. " + "Please install it with `pip install mlx_lm`." + ) + + # fall back to the pipeline_kwargs set at construction time + pipeline_kwargs = kwargs.get("pipeline_kwargs", self.pipeline_kwargs) or {} + + return generate(self.model, self.tokenizer, prompt=prompt, **pipeline_kwargs) + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + try: + import mlx.core as mx + from mlx_lm.utils import generate_step + + except ImportError: + raise ValueError( + "Could not import mlx_lm python package. " + "Please install it with `pip install mlx_lm`."
+ ) + + pipeline_kwargs = kwargs.get("pipeline_kwargs", self.pipeline_kwargs) or {} + + temp: float = pipeline_kwargs.get("temp", 0.0) + max_new_tokens: int = pipeline_kwargs.get("max_tokens", 100) + repetition_penalty: Optional[float] = pipeline_kwargs.get( + "repetition_penalty", None + ) + repetition_context_size: Optional[int] = pipeline_kwargs.get( + "repetition_context_size", None + ) + + prompt = self.tokenizer.encode(prompt, return_tensors="np") + + prompt_tokens = mx.array(prompt[0]) + + eos_token_id = self.tokenizer.eos_token_id + + for (token, prob), n in zip( + generate_step( + prompt_tokens, + self.model, + temp, + repetition_penalty, + repetition_context_size, + ), + range(max_new_tokens), + ): + # identify text to yield + text: Optional[str] = None + text = self.tokenizer.decode(token.item()) + + # yield text, if any + if text: + chunk = GenerationChunk(text=text) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text) + + # break if stop sequence found + if token == eos_token_id or (stop is not None and text in stop): + break diff --git a/libs/community/tests/integration_tests/chat_models/test_mlx.py b/libs/community/tests/integration_tests/chat_models/test_mlx.py new file mode 100644 index 00000000000..00b0ea57a9d --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_mlx.py @@ -0,0 +1,37 @@ +"""Test MLX Chat Model.""" + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage + +from langchain_community.chat_models.mlx import ChatMLX +from langchain_community.llms.mlx_pipeline import MLXPipeline + + +def test_default_call() -> None: + """Test default model call.""" + llm = MLXPipeline.from_model_id( + model_id="mlx-community/quantized-gemma-2b-it", + pipeline_kwargs={"max_tokens": 10}, + ) + chat = ChatMLX(llm=llm) + response = chat.invoke(input=[HumanMessage(content="Hello")]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_multiple_history() -> None: + """Tests multiple history works.""" + llm = MLXPipeline.from_model_id( + model_id="mlx-community/quantized-gemma-2b-it", + pipeline_kwargs={"max_tokens": 10}, + ) + chat = ChatMLX(llm=llm) + + response = chat.invoke( + input=[ + HumanMessage(content="Hello."), + AIMessage(content="Hello!"), + HumanMessage(content="How are you doing?"), + ] + ) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) diff --git a/libs/community/tests/integration_tests/llms/test_mlx_pipeline.py b/libs/community/tests/integration_tests/llms/test_mlx_pipeline.py new file mode 100755 index 00000000000..92179cfe416 --- /dev/null +++ b/libs/community/tests/integration_tests/llms/test_mlx_pipeline.py @@ -0,0 +1,33 @@ +"""Test MLX Pipeline wrapper.""" + +from langchain_community.llms.mlx_pipeline import MLXPipeline + + +def test_mlx_pipeline_text_generation() -> None: + """Test valid call to MLX text generation model.""" + llm = MLXPipeline.from_model_id( + model_id="mlx-community/quantized-gemma-2b", + pipeline_kwargs={"max_tokens": 10}, + ) + output = llm.invoke("Say foo:") + assert isinstance(output, str) + + +def test_init_with_model_and_tokenizer() -> None: + """Test initialization with a model and tokenizer loaded via mlx_lm.""" + from mlx_lm import load + + model, tokenizer = load("mlx-community/quantized-gemma-2b") + llm = MLXPipeline(model=model, tokenizer=tokenizer) + output = llm.invoke("Say foo:") + assert isinstance(output, str) + + +def test_mlx_pipeline_runtime_kwargs() -> None: + """Test passing pipeline kwargs at invocation time."""
+ llm = MLXPipeline.from_model_id( + model_id="mlx-community/quantized-gemma-2b", + ) + prompt = "Say foo:" + output = llm.invoke(prompt, pipeline_kwargs={"max_tokens": 2}) + assert len(output) < 10 diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index cca1330eaa5..5bffff9de0b 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -16,6 +16,7 @@ EXPECTED_ALL = [ "ChatMaritalk", "ChatMlflow", "ChatMLflowAIGateway", + "ChatMLX", "ChatOllama", "ChatVertexAI", "JinaChat", diff --git a/libs/community/tests/unit_tests/chat_models/test_mlx.py b/libs/community/tests/unit_tests/chat_models/test_mlx.py new file mode 100644 index 00000000000..5add10a4356 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_mlx.py @@ -0,0 +1,11 @@ +"""Test MLX Chat wrapper.""" +from importlib import import_module + + +def test_import_class() -> None: + """Test that the class can be imported.""" + module_name = "langchain_community.chat_models.mlx" + class_name = "ChatMLX" + + module = import_module(module_name) + assert hasattr(module, class_name) diff --git a/libs/community/tests/unit_tests/llms/test_imports.py b/libs/community/tests/unit_tests/llms/test_imports.py index a4e83da31d4..64cfaec9c50 100644 --- a/libs/community/tests/unit_tests/llms/test_imports.py +++ b/libs/community/tests/unit_tests/llms/test_imports.py @@ -52,6 +52,7 @@ EXPECT_ALL = [ "Minimax", "Mlflow", "MlflowAIGateway", + "MLXPipeline", "Modal", "MosaicML", "Nebula",
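
Usage sketch (not part of the diff above): because both `MLXPipeline._stream` and `ChatMLX._stream` are implemented, the standard LangChain `.stream()` interface should work once `mlx-lm` is installed on Apple silicon. The model id and pipeline_kwargs below simply mirror the notebook examples; treat this as an illustrative assumption, not additional tested code.

# Hypothetical streaming example built on the classes added in this PR.
# Assumes an Apple-silicon machine with `mlx-lm` installed; the model id and
# pipeline_kwargs mirror the docs above and are not part of the change set.
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.messages import HumanMessage

llm = MLXPipeline.from_model_id(
    "mlx-community/quantized-gemma-2b-it",
    pipeline_kwargs={"max_tokens": 50, "temp": 0.1},
)
chat = ChatMLX(llm=llm)

# BaseChatModel.stream() delegates to the ChatMLX._stream implementation above,
# yielding AIMessageChunk objects token by token.
for chunk in chat.stream([HumanMessage(content="What is MLX?")]):
    print(chunk.content, end="", flush=True)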