community/mlx_pipeline: fix crash at mlx call (#29915)

- **Description:** 
Since mlx_lm 0.20, all calls into mlx crash because of a breaking change in how
sampling parameters are passed to the `generate` and `generate_step` methods.
The parameters `top_p`, `temp`, `repetition_penalty`, and
`repetition_context_size` are no longer passed directly to those methods;
they now have to be wrapped into a `sampler` and a `logits_processors`
argument (built with `make_sampler` and `make_logits_processors`). A minimal
sketch of the new calling convention follows the notes below.


- **Dependencies:** mlx_lm (optional)

-  **Tests:** 
I've added a new test to the existing test file
`tests/integration_tests/llms/test_mlx_pipeline.py`.
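
For reviewers unfamiliar with the mlx_lm 0.20 API change, here is a minimal, hedged sketch of the new calling convention. The helper names and argument order are taken from the diff below; the surrounding function, its defaults, and variable names are illustrative only.

```python
# Illustrative sketch only, not code from this PR: shows how sampling options
# are wrapped for mlx_lm >= 0.20 instead of being passed to generate_step directly.
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step


def stream_tokens(model, prompt_tokens, max_new_tokens=10, temp=0.7, top_p=1.0,
                  min_p=0.0, min_tokens_to_keep=1, repetition_penalty=None,
                  repetition_context_size=None):
    # mlx_lm < 0.20: generate_step(prompt_tokens, model, temp=..., top_p=...,
    #                              repetition_penalty=..., repetition_context_size=...)
    # mlx_lm >= 0.20: build a sampler and logits processors and pass those instead.
    sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
    logits_processors = make_logits_processors(
        None, repetition_penalty, repetition_context_size
    )
    for (token, _logprobs), _ in zip(
        generate_step(prompt_tokens, model, sampler=sampler,
                      logits_processors=logits_processors),
        range(max_new_tokens),
    ):
        yield token
```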

---------

Co-authored-by: Jean-Philippe Dournel <jp@insightkeeper.io>
Jean-Philippe Dournel authored on 2025-02-21 15:14:53 +01:00 (committed by GitHub)
commit ebe38baaf9 (parent bd773cffc3)
3 changed files with 59 additions and 12 deletions

langchain_community/chat_models/mlx.py

@@ -160,6 +160,7 @@ class ChatMLX(BaseChatModel):
         try:
             import mlx.core as mx
+            from mlx_lm.sample_utils import make_logits_processors, make_sampler
             from mlx_lm.utils import generate_step
         except ImportError:
@@ -176,6 +177,9 @@ class ChatMLX(BaseChatModel):
         repetition_context_size: Optional[int] = model_kwargs.get(
             "repetition_context_size", None
         )
+        top_p: float = model_kwargs.get("top_p", 1.0)
+        min_p: float = model_kwargs.get("min_p", 0.0)
+        min_tokens_to_keep: int = model_kwargs.get("min_tokens_to_keep", 1)

         llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")
@@ -183,13 +187,18 @@ class ChatMLX(BaseChatModel):
         eos_token_id = self.tokenizer.eos_token_id

+        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
+        logits_processors = make_logits_processors(
+            None, repetition_penalty, repetition_context_size
+        )
         for (token, prob), n in zip(
             generate_step(
                 prompt_tokens,
                 self.llm.model,
-                temp=temp,
-                repetition_penalty=repetition_penalty,
-                repetition_context_size=repetition_context_size,
+                sampler=sampler,
+                logits_processors=logits_processors,
             ),
             range(max_new_tokens),
         ):

langchain_community/llms/mlx_pipeline.py

@@ -101,7 +101,9 @@ class MLXPipeline(LLM):
         tokenizer_config = tokenizer_config or {}
         if adapter_file:
-            model, tokenizer = load(model_id, tokenizer_config, adapter_file, lazy)
+            model, tokenizer = load(
+                model_id, tokenizer_config, adapter_path=adapter_file, lazy=lazy
+            )
         else:
             model, tokenizer = load(model_id, tokenizer_config, lazy=lazy)
@@ -141,6 +143,7 @@ class MLXPipeline(LLM):
     ) -> str:
         try:
             from mlx_lm import generate
+            from mlx_lm.sample_utils import make_logits_processors, make_sampler
         except ImportError:
             raise ImportError(
@@ -161,18 +164,23 @@ class MLXPipeline(LLM):
             "repetition_context_size", None
         )
         top_p: float = pipeline_kwargs.get("top_p", 1.0)
+        min_p: float = pipeline_kwargs.get("min_p", 0.0)
+        min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)

+        sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+        logits_processors = make_logits_processors(
+            None, repetition_penalty, repetition_context_size
+        )
         return generate(
             model=self.model,
             tokenizer=self.tokenizer,
             prompt=prompt,
-            temp=temp,
             max_tokens=max_tokens,
             verbose=verbose,
             formatter=formatter,
-            repetition_penalty=repetition_penalty,
-            repetition_context_size=repetition_context_size,
-            top_p=top_p,
+            sampler=sampler,
+            logits_processors=logits_processors,
         )

     def _stream(
@@ -184,6 +192,7 @@ class MLXPipeline(LLM):
     ) -> Iterator[GenerationChunk]:
         try:
             import mlx.core as mx
+            from mlx_lm.sample_utils import make_logits_processors, make_sampler
             from mlx_lm.utils import generate_step
         except ImportError:
@@ -203,6 +212,8 @@ class MLXPipeline(LLM):
             "repetition_context_size", None
         )
         top_p: float = pipeline_kwargs.get("top_p", 1.0)
+        min_p: float = pipeline_kwargs.get("min_p", 0.0)
+        min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)

         prompt = self.tokenizer.encode(prompt, return_tensors="np")
@@ -212,14 +223,18 @@ class MLXPipeline(LLM):
         detokenizer = self.tokenizer.detokenizer
         detokenizer.reset()

+        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
+        logits_processors = make_logits_processors(
+            None, repetition_penalty, repetition_context_size
+        )
         for (token, prob), n in zip(
             generate_step(
                 prompt=prompt_tokens,
                 model=self.model,
-                temp=temp,
-                repetition_penalty=repetition_penalty,
-                repetition_context_size=repetition_context_size,
-                top_p=top_p,
+                sampler=sampler,
+                logits_processors=logits_processors,
             ),
             range(max_new_tokens),
         ):

tests/integration_tests/llms/test_mlx_pipeline.py

@@ -1,8 +1,11 @@
 """Test MLX Pipeline wrapper."""

+import pytest
+
 from langchain_community.llms.mlx_pipeline import MLXPipeline


+@pytest.mark.requires("mlx_lm")
 def test_mlx_pipeline_text_generation() -> None:
     """Test valid call to MLX text generation model."""
     llm = MLXPipeline.from_model_id(
@@ -13,6 +16,7 @@ def test_mlx_pipeline_text_generation() -> None:
     assert isinstance(output, str)


+@pytest.mark.requires("mlx_lm")
 def test_init_with_model_and_tokenizer() -> None:
     """Test initialization with a HF pipeline."""
     from mlx_lm import load
@@ -23,6 +27,7 @@ def test_init_with_model_and_tokenizer() -> None:
     assert isinstance(output, str)


+@pytest.mark.requires("mlx_lm")
 def test_huggingface_pipeline_runtime_kwargs() -> None:
     """Test pipelines specifying the device map parameter."""
     llm = MLXPipeline.from_model_id(
@@ -31,3 +36,21 @@ def test_huggingface_pipeline_runtime_kwargs() -> None:
     prompt = "Say foo:"
     output = llm.invoke(prompt, pipeline_kwargs={"max_tokens": 2})
     assert len(output) < 10
+
+
+@pytest.mark.requires("mlx_lm")
+def test_mlx_pipeline_with_params() -> None:
+    """Test valid call to MLX text generation model."""
+    llm = MLXPipeline.from_model_id(
+        model_id="mlx-community/quantized-gemma-2b",
+        pipeline_kwargs={
+            "max_tokens": 10,
+            "temp": 0.8,
+            "verbose": False,
+            "repetition_penalty": 1.1,
+            "repetition_context_size": 64,
+            "top_p": 0.95,
+        },
+    )
+    output = llm.invoke("Say foo:")
+    assert isinstance(output, str)
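
As a follow-up illustration (not part of the diff), the new sampler-related options can be exercised end to end through the public `MLXPipeline` API. The model id is reused from the test above; the `min_p` value is arbitrary and only shown as an example.

```python
# Hedged usage sketch: passing the sampler-related kwargs that this PR threads
# through MLXPipeline into make_sampler / make_logits_processors.
from langchain_community.llms.mlx_pipeline import MLXPipeline

llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b",  # same model id as the test file
    pipeline_kwargs={
        "max_tokens": 10,
        "temp": 0.8,
        "top_p": 0.95,
        "min_p": 0.05,  # new option, forwarded to make_sampler
        "min_tokens_to_keep": 1,  # new option, forwarded to make_sampler
        "repetition_penalty": 1.1,  # now applied via make_logits_processors
        "repetition_context_size": 64,
    },
)
print(llm.invoke("Say foo:"))
```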