community[minor]: Add support for MLX models (chat & llm) (#18152)

**Description:** This PR adds support for MLX models, both chat (i.e., instruct) and llm (i.e., pretrained) types (see the usage sketch below).
**Dependencies:** mlx, mlx_lm, transformers
**Twitter handle:** @Prince_Canuma
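
A minimal usage sketch, distilled from the integration tests in this diff. The model ids and kwargs are taken directly from those tests; having `mlx` and `mlx_lm` installed and being able to download the models from the Hub is assumed:

```python
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

# llm (pretrained) usage: wrap an MLX model as a LangChain LLM.
llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b",
    pipeline_kwargs={"max_tokens": 10},
)
print(llm.invoke("Say foo:"))

# chat (instruct) usage: wrap an instruct pipeline in ChatMLX.
chat = ChatMLX(
    llm=MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b-it",
        pipeline_kwargs={"max_new_tokens": 10},
    )
)
print(chat.invoke([HumanMessage(content="Hello")]).content)
```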

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Commit 1f9f4d8742 (parent 6baeaf4802), authored by Prince Canuma on 2024-04-09 16:17:07 +02:00, committed by GitHub. 11 changed files with 848 additions and 0 deletions.


@@ -0,0 +1,37 @@
"""Test MLX Chat Model."""
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
def test_default_call() -> None:
"""Test default model call."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b-it",
pipeline_kwargs={"max_new_tokens": 10},
)
chat = ChatMLX(llm=llm)
response = chat.invoke(input=[HumanMessage(content="Hello")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_history() -> None:
"""Tests multiple history works."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b-it",
pipeline_kwargs={"max_new_tokens": 10},
)
chat = ChatMLX(llm=llm)
response = chat.invoke(
input=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="How are you doing?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
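
Beyond direct invocation, `ChatMLX` should compose like any other LangChain chat model. A minimal LCEL sketch (an assumption based on the standard `Runnable` interface, not part of this PR's diff):

```python
from langchain_core.prompts import ChatPromptTemplate

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

# Same model id and kwargs as the chat tests above.
llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b-it",
    pipeline_kwargs={"max_new_tokens": 10},
)
chat = ChatMLX(llm=llm)

# Pipe a prompt template into the chat model.
prompt = ChatPromptTemplate.from_messages([("human", "Tell me a joke about {topic}")])
chain = prompt | chat
print(chain.invoke({"topic": "bears"}).content)
```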


@@ -0,0 +1,33 @@
"""Test MLX Pipeline wrapper."""
from langchain_community.llms.mlx_pipeline import MLXPipeline
def test_mlx_pipeline_text_generation() -> None:
"""Test valid call to MLX text generation model."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b",
pipeline_kwargs={"max_tokens": 10},
)
output = llm.invoke("Say foo:")
assert isinstance(output, str)
def test_init_with_model_and_tokenizer() -> None:
"""Test initialization with a HF pipeline."""
from mlx_lm import load
model, tokenizer = load("mlx-community/quantized-gemma-2b")
llm = MLXPipeline(model=model, tokenizer=tokenizer)
output = llm.invoke("Say foo:")
assert isinstance(output, str)
def test_huggingface_pipeline_runtime_kwargs() -> None:
"""Test pipelines specifying the device map parameter."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b",
)
prompt = "Say foo:"
output = llm.invoke(prompt, pipeline_kwargs={"max_tokens": 2})
assert len(output) < 10
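
Usage note: the final test constructs the pipeline without any `pipeline_kwargs` and instead passes them at call time, which suggests that per-invocation `pipeline_kwargs` are merged into (and presumably take precedence over) any set via `from_model_id`; the loose length assertion is a proxy for `max_tokens` being honored.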