community/mlx_pipeline: fix crash at mlx call (#29915)
- **Description:** Since mlx_lm 0.20, all calls to mlx crash because of a change in how parameters are passed to the `generate` and `generate_step` methods: `top_p`, `temp`, `repetition_penalty` and `repetition_context_size` are no longer passed directly to those methods but are wrapped into a "sampler" and "logits_processors" instead.
- **Dependencies:** mlx_lm (optional)
- **Tests:** Added a new test to the existing test file: tests/integration_tests/llms/test_mlx_pipeline.py

---------

Co-authored-by: Jean-Philippe Dournel <jp@insightkeeper.io>
committed by GitHub
parent bd773cffc3
commit ebe38baaf9
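For context, here is a minimal sketch of the calling convention this commit adapts to: with mlx_lm >= 0.20, sampling options go through `make_sampler` / `make_logits_processors` instead of being passed to `generate` directly. The calls mirror the ones in the diff below; the model id and the concrete option values are illustrative placeholders, not part of the change.

```python
# Minimal sketch of the mlx_lm >= 0.20 API used by this fix.
# The model id and option values below are illustrative placeholders.
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/quantized-gemma-2b-it")  # placeholder model id

# Sampling options are bundled into a sampler: temp, top_p, min_p, min_tokens_to_keep.
sampler = make_sampler(0.7, 0.95, 0.0, 1)
# Repetition handling is bundled into logits processors
# (the first argument is a logit-bias mapping, unused here).
logits_processors = make_logits_processors(None, 1.1, 20)

text = generate(
    model=model,
    tokenizer=tokenizer,
    prompt="What is MLX?",
    max_tokens=128,
    sampler=sampler,
    logits_processors=logits_processors,
)
print(text)
```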
```diff
@@ -101,7 +101,9 @@ class MLXPipeline(LLM):
         tokenizer_config = tokenizer_config or {}
         if adapter_file:
-            model, tokenizer = load(model_id, tokenizer_config, adapter_file, lazy)
+            model, tokenizer = load(
+                model_id, tokenizer_config, adapter_path=adapter_file, lazy=lazy
+            )
         else:
             model, tokenizer = load(model_id, tokenizer_config, lazy=lazy)
@@ -141,6 +143,7 @@ class MLXPipeline(LLM):
     ) -> str:
         try:
             from mlx_lm import generate
+            from mlx_lm.sample_utils import make_logits_processors, make_sampler
 
         except ImportError:
             raise ImportError(
@@ -161,18 +164,23 @@ class MLXPipeline(LLM):
             "repetition_context_size", None
         )
         top_p: float = pipeline_kwargs.get("top_p", 1.0)
+        min_p: float = pipeline_kwargs.get("min_p", 0.0)
+        min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)
+
+        sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+        logits_processors = make_logits_processors(
+            None, repetition_penalty, repetition_context_size
+        )
 
         return generate(
             model=self.model,
             tokenizer=self.tokenizer,
             prompt=prompt,
-            temp=temp,
             max_tokens=max_tokens,
             verbose=verbose,
             formatter=formatter,
-            repetition_penalty=repetition_penalty,
-            repetition_context_size=repetition_context_size,
-            top_p=top_p,
+            sampler=sampler,
+            logits_processors=logits_processors,
         )
 
     def _stream(
```
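A hedged usage sketch of the wrapper after this change: the new `min_p` and `min_tokens_to_keep` keys read above are supplied through `pipeline_kwargs` together with the existing sampling options. The model id, option values, and the exact set of keys are assumptions made for a self-contained example.

```python
# Usage sketch (assumed model id and option values); exercises the updated _call path.
from langchain_community.llms.mlx_pipeline import MLXPipeline

llm = MLXPipeline.from_model_id(
    "mlx-community/quantized-gemma-2b-it",  # placeholder model id
    pipeline_kwargs={
        "temp": 0.7,
        "top_p": 0.95,
        "min_p": 0.0,
        "min_tokens_to_keep": 1,
        "repetition_penalty": 1.1,
        "repetition_context_size": 20,
        "max_tokens": 128,
    },
)
print(llm.invoke("What is MLX?"))
```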
```diff
@@ -184,6 +192,7 @@ class MLXPipeline(LLM):
     ) -> Iterator[GenerationChunk]:
         try:
             import mlx.core as mx
+            from mlx_lm.sample_utils import make_logits_processors, make_sampler
             from mlx_lm.utils import generate_step
 
         except ImportError:
@@ -203,6 +212,8 @@ class MLXPipeline(LLM):
             "repetition_context_size", None
         )
         top_p: float = pipeline_kwargs.get("top_p", 1.0)
+        min_p: float = pipeline_kwargs.get("min_p", 0.0)
+        min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)
 
         prompt = self.tokenizer.encode(prompt, return_tensors="np")
 
@@ -212,14 +223,18 @@ class MLXPipeline(LLM):
         detokenizer = self.tokenizer.detokenizer
         detokenizer.reset()
 
+        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
+
+        logits_processors = make_logits_processors(
+            None, repetition_penalty, repetition_context_size
+        )
+
         for (token, prob), n in zip(
             generate_step(
                 prompt=prompt_tokens,
                 model=self.model,
-                temp=temp,
-                repetition_penalty=repetition_penalty,
-                repetition_context_size=repetition_context_size,
-                top_p=top_p,
+                sampler=sampler,
+                logits_processors=logits_processors,
             ),
             range(max_new_tokens),
         ):
```
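Finally, a standalone sketch of the streaming loop in its new form, mirroring the `_stream` hunk above. The model id, the way `prompt_tokens` is built, and the token-id cast are assumptions made so the example runs on its own.

```python
# Streaming sketch against mlx_lm >= 0.20, mirroring the updated _stream loop.
# Model id, prompt handling and the int() cast are assumptions for illustration.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/quantized-gemma-2b-it")  # placeholder model id
prompt = tokenizer.encode("What is MLX?", return_tensors="np")
prompt_tokens = mx.array(prompt[0])

sampler = make_sampler(0.7, 0.95, 0.0, 1)  # temp, top_p, min_p, min_tokens_to_keep
logits_processors = make_logits_processors(None, 1.1, 20)

token_ids = []
for (token, _prob), _n in zip(
    generate_step(
        prompt=prompt_tokens,
        model=model,
        sampler=sampler,
        logits_processors=logits_processors,
    ),
    range(64),  # max new tokens
):
    token_id = int(token)  # token may be an mx scalar; cast for plain-Python handling
    if token_id == tokenizer.eos_token_id:
        break
    token_ids.append(token_id)

print(tokenizer.decode(token_ids))
```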