langchain/libs/community/tests/integration_tests/llms/test_mlx_pipeline.py
Jean-Philippe Dournel ebe38baaf9
community/mlx_pipeline: fix crash at mlx call (#29915)
- **Description:** 
Since mlx_lm 0.20, all calls to mlx crash because of a change in how
parameters are passed to the methods generate and generate_step.
The parameters top_p, temp, repetition_penalty and repetition_context_size
are no longer passed directly to those methods; they must instead be
wrapped into a "sampler" and "logits_processors" (see the sketch below).


- **Dependencies:** mlx_lm (optional)

- **Tests:** 
I've added a new test to the existing test file:
tests/integration_tests/llms/test_mlx_pipeline.py

---------

Co-authored-by: Jean-Philippe Dournel <jp@insightkeeper.io>
2025-02-21 09:14:53 -05:00


"""Test MLX Pipeline wrapper."""
import pytest
from langchain_community.llms.mlx_pipeline import MLXPipeline
@pytest.mark.requires("mlx_lm")
def test_mlx_pipeline_text_generation() -> None:
"""Test valid call to MLX text generation model."""
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b",
pipeline_kwargs={"max_tokens": 10},
)
output = llm.invoke("Say foo:")
assert isinstance(output, str)


@pytest.mark.requires("mlx_lm")
def test_init_with_model_and_tokenizer() -> None:
    """Test initialization with an mlx_lm model and tokenizer."""
    from mlx_lm import load

    model, tokenizer = load("mlx-community/quantized-gemma-2b")
    llm = MLXPipeline(model=model, tokenizer=tokenizer)
    output = llm.invoke("Say foo:")
    assert isinstance(output, str)


@pytest.mark.requires("mlx_lm")
def test_huggingface_pipeline_runtime_kwargs() -> None:
    """Test passing pipeline_kwargs at invocation time."""
    llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b",
    )
    prompt = "Say foo:"
    output = llm.invoke(prompt, pipeline_kwargs={"max_tokens": 2})
    assert len(output) < 10


@pytest.mark.requires("mlx_lm")
def test_mlx_pipeline_with_params() -> None:
    """Test valid call to MLX text generation model with sampling parameters."""
    llm = MLXPipeline.from_model_id(
        model_id="mlx-community/quantized-gemma-2b",
        pipeline_kwargs={
            "max_tokens": 10,
            "temp": 0.8,
            "verbose": False,
            "repetition_penalty": 1.1,
            "repetition_context_size": 64,
            "top_p": 0.95,
        },
    )
    output = llm.invoke("Say foo:")
    assert isinstance(output, str)