diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py
index e3a28c73c7f..3cd687e87c2 100644
--- a/libs/community/langchain_community/chat_models/mlx.py
+++ b/libs/community/langchain_community/chat_models/mlx.py
@@ -1,11 +1,23 @@
 """MLX Chat Wrapper."""
 
-from typing import Any, Iterator, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -20,6 +32,9 @@ from langchain_core.outputs import (
     ChatResult,
     LLMResult,
 )
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_community.llms.mlx_pipeline import MLXPipeline
 
@@ -94,7 +109,6 @@ class ChatMLX(BaseChatModel):
             raise ValueError("Last message must be a HumanMessage!")
 
         messages_dicts = [self._to_chatml_format(m) for m in messages]
-
         return self.tokenizer.apply_chat_template(
             messages_dicts,
             tokenize=tokenize,
@@ -173,15 +187,18 @@ class ChatMLX(BaseChatModel):
             generate_step(
                 prompt_tokens,
                 self.llm.model,
-                temp,
-                repetition_penalty,
-                repetition_context_size,
+                temp=temp,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_new_tokens),
         ):
             # identify text to yield
             text: Optional[str] = None
-            text = self.tokenizer.decode(token.item())
+            if not isinstance(token, int):
+                text = self.tokenizer.decode(token.item())
+            else:
+                text = self.tokenizer.decode(token)
 
             # yield text, if any
             if text:
@@ -193,3 +210,59 @@ class ChatMLX(BaseChatModel):
             # break if stop sequence found
             if token == eos_token_id or (stop is not None and text in stop):
                 break
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+        *,
+        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Supports any tool definition handled by
+                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
+            tool_choice: Which tool to require the model to call.
+                Must be the name of the single provided function or
+                "auto" to automatically determine which function to call
+                (if any), or a dict of the form:
+                {"type": "function", "function": {"name": <<tool_name>>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        if tool_choice is not None and tool_choice:
+            if len(formatted_tools) != 1:
+                raise ValueError(
+                    "When specifying `tool_choice`, you must provide exactly one "
+                    f"tool. Received {len(formatted_tools)} tools."
+                )
+            if isinstance(tool_choice, str):
+                if tool_choice not in ("auto", "none"):
+                    tool_choice = {
+                        "type": "function",
+                        "function": {"name": tool_choice},
+                    }
+            elif isinstance(tool_choice, bool):
+                tool_choice = formatted_tools[0]
+            elif isinstance(tool_choice, dict):
+                if (
+                    formatted_tools[0]["function"]["name"]
+                    != tool_choice["function"]["name"]
+                ):
+                    raise ValueError(
+                        f"Tool choice {tool_choice} was specified, but the only "
+                        f"provided tool was {formatted_tools[0]['function']['name']}."
+                    )
+            else:
+                raise ValueError(
+                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
+                    f"Received: {tool_choice}"
+                )
+            kwargs["tool_choice"] = tool_choice
+        return super().bind(tools=formatted_tools, **kwargs)
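For context, a minimal usage sketch of the new `bind_tools` method. The model id and the `multiply` tool below are illustrative placeholders, not part of this diff:

```python
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.tools import tool


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


# Placeholder model id; any MLX-format chat model should work here.
llm = MLXPipeline.from_model_id("mlx-community/quantized-gemma-2b-it")
chat = ChatMLX(llm=llm)

# A bare tool name is normalized by bind_tools into
# {"type": "function", "function": {"name": "multiply"}}.
chat_with_tools = chat.bind_tools([multiply], tool_choice="multiply")
```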
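Per the branching above, `tool_choice=True` binds the single formatted tool definition itself, `"auto"` and `"none"` pass through unchanged, and a dict is validated against the name of the one provided tool. Note that `bind_tools` only formats the tools via `convert_to_openai_tool` and attaches them as bound kwargs; whether the underlying MLX model actually emits tool calls still depends on its chat template.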