langchain/libs/community/langchain_community/chat_models/mlx.py
Jean-Philippe Dournel ebe38baaf9
community/mlx_pipeline: fix crash at mlx call (#29915)
- **Description:**
Since mlx_lm 0.20, all calls to mlx crash because the old way of passing
parameters to the `generate` and `generate_step` methods was deprecated.
`top_p`, `temp`, `repetition_penalty`, and `repetition_context_size` are no
longer passed to those methods directly; they are wrapped into a `sampler`
and `logits_processors` (see the sketch after this list).


- **Dependencies:** mlx_lm (optional)

- **Tests:**
I've added a new test to the existing test file:
tests/integration_tests/llms/test_mlx_pipeline.py
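
For reference, a minimal sketch of the new-style call, mirroring the usage in `_stream` below; the model id (taken from the `ChatMLX` docstring), prompt, and sampling values are illustrative placeholders:

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

# Illustrative model id and prompt.
model, tokenizer = load("mlx-community/quantized-gemma-2b-it")
prompt = mx.array(tokenizer.encode("Hello!"))

# mlx_lm >= 0.20: sampling options are bundled into callables instead of being
# passed as individual keyword arguments to generate_step.
sampler = make_sampler(temp=0.7, top_p=0.9, min_p=0.0, min_tokens_to_keep=1)
logits_processors = make_logits_processors(
    logit_bias=None, repetition_penalty=1.1, repetition_context_size=20
)

for (token, _logprobs), _ in zip(
    generate_step(prompt, model, sampler=sampler, logits_processors=logits_processors),
    range(100),  # cap the number of new tokens
):
    text = tokenizer.decode(token if isinstance(token, int) else token.item())
    print(text, end="", flush=True)
```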

---------

Co-authored-by: Jean-Philippe Dournel <jp@insightkeeper.io>
2025-02-21 09:14:53 -05:00


"""MLX Chat Wrapper."""
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_community.llms.mlx_pipeline import MLXPipeline
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
class ChatMLX(BaseChatModel):
"""MLX chat models.
Works with `MLXPipeline` LLM.
To use, you should have the ``mlx-lm`` python package installed.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatMLX
from langchain_community.llms import MLXPipeline
llm = MLXPipeline.from_model_id(
model_id="mlx-community/quantized-gemma-2b-it",
)
chat = ChatMLX(llm=llm)
"""
llm: MLXPipeline
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tokenizer = self.llm.tokenizer
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
def _to_chat_prompt(
self,
messages: List[BaseMessage],
tokenize: bool = False,
return_tensors: Optional[str] = None,
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")
if not isinstance(messages[-1], HumanMessage):
raise ValueError("Last message must be a HumanMessage!")
messages_dicts = [self._to_chatml_format(m) for m in messages]
return self.tokenizer.apply_chat_template(
messages_dicts,
tokenize=tokenize,
add_generation_prompt=True,
return_tensors=return_tensors,
)
def _to_chatml_format(self, message: BaseMessage) -> dict:
"""Convert LangChain message to ChatML format."""
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(f"Unknown message type: {type(message)}")
return {"role": role, "content": message.content}
@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []
for g in llm_result.generations[0]:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
)
chat_generations.append(chat_generation)
return ChatResult(
generations=chat_generations, llm_output=llm_result.llm_output
)
@property
def _llm_type(self) -> str:
return "mlx-chat-wrapper"
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
try:
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step
except ImportError:
raise ImportError(
"Could not import mlx_lm python package. "
"Please install it with `pip install mlx_lm`."
)
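# Read generation options from call-time model_kwargs, falling back to the
# pipeline's default pipeline_kwargs.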
model_kwargs = kwargs.get("model_kwargs", self.llm.pipeline_kwargs)
temp: float = model_kwargs.get("temp", 0.0)
max_new_tokens: int = model_kwargs.get("max_tokens", 100)
repetition_penalty: Optional[float] = model_kwargs.get(
"repetition_penalty", None
)
repetition_context_size: Optional[int] = model_kwargs.get(
"repetition_context_size", None
)
top_p: float = model_kwargs.get("top_p", 1.0)
min_p: float = model_kwargs.get("min_p", 0.0)
min_tokens_to_keep: int = model_kwargs.get("min_tokens_to_keep", 1)
llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")
prompt_tokens = mx.array(llm_input[0])
eos_token_id = self.tokenizer.eos_token_id
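# mlx_lm >= 0.20 expects sampling and repetition-penalty options to be wrapped
# into a sampler and logits processors rather than passed to generate_step.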
sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
logits_processors = make_logits_processors(
None, repetition_penalty, repetition_context_size
)
for (token, prob), n in zip(
generate_step(
prompt_tokens,
self.llm.model,
sampler=sampler,
logits_processors=logits_processors,
),
range(max_new_tokens),
):
# identify text to yield
text: Optional[str] = None
if not isinstance(token, int):
text = self.tokenizer.decode(token.item())
else:
text = self.tokenizer.decode(token)
# yield text, if any
if text:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
# break if stop sequence found
if token == eos_token_id or (stop is not None and text in stop):
break
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
*,
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Supports any tool definition handled by
:meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
tool_choice: Which tool to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any), or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice:
if len(formatted_tools) != 1:
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
if isinstance(tool_choice, str):
if tool_choice not in ("auto", "none"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
elif isinstance(tool_choice, bool):
tool_choice = formatted_tools[0]
elif isinstance(tool_choice, dict):
if (
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)