Added `bind_tools` support for ChatMLX along with small fixes in `_stream` (#28743)

- **Description:** Added support for `bind_tools` as requested in the
issue (a usage sketch follows below). In addition, two issues in `_stream` were fixed:
    - Corrected the positional-argument passing for `generate_step`.
    - Handled the case where the `token` returned by `generate_step` is a plain integer.
- **Issue:** #28692
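
For illustration, a minimal usage sketch of the new `bind_tools` support (the model id and the `multiply` tool here are hypothetical examples, not part of this change):

```python
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.tools import tool


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


# any MLX-compatible chat model should work; this id is just an example
llm = MLXPipeline.from_model_id("mlx-community/quantized-gemma-2b-it")
chat = ChatMLX(llm=llm)

# bind the tool so it is passed to the model on every invocation
chat_with_tools = chat.bind_tools([multiply])
ai_msg = chat_with_tools.invoke("What is 6 times 7?")
```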
Mohammad Mohtashim 2024-12-16 19:52:49 +05:00 committed by GitHub
parent 558b65ea32
commit 8d746086ab


@@ -1,11 +1,23 @@
 """MLX Chat Wrapper."""

-from typing import Any, Iterator, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -20,6 +32,9 @@ from langchain_core.outputs import (
     ChatResult,
     LLMResult,
 )
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool

 from langchain_community.llms.mlx_pipeline import MLXPipeline
@@ -94,7 +109,6 @@ class ChatMLX(BaseChatModel):
             raise ValueError("Last message must be a HumanMessage!")

         messages_dicts = [self._to_chatml_format(m) for m in messages]

         return self.tokenizer.apply_chat_template(
             messages_dicts,
             tokenize=tokenize,
@@ -173,15 +187,18 @@ class ChatMLX(BaseChatModel):
             generate_step(
                 prompt_tokens,
                 self.llm.model,
-                temp,
-                repetition_penalty,
-                repetition_context_size,
+                temp=temp,
+                repetition_penalty=repetition_penalty,
+                repetition_context_size=repetition_context_size,
             ),
             range(max_new_tokens),
         ):
             # identify text to yield
             text: Optional[str] = None
-            text = self.tokenizer.decode(token.item())
+            if not isinstance(token, int):
+                text = self.tokenizer.decode(token.item())
+            else:
+                text = self.tokenizer.decode(token)

             # yield text, if any
             if text:
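
Two things are worth noting about the hunk above: passing `temp`, `repetition_penalty`, and `repetition_context_size` as keywords protects against positional drift in `generate_step`'s signature across `mlx-lm` releases, and the `isinstance` check covers token ids arriving either as array scalars (decoded via `.item()`) or as plain Python ints. A standalone sketch of the same guard (the `decode_token` helper is hypothetical):

```python
from typing import Any


def decode_token(tokenizer: Any, token: Any) -> str:
    # normalize the token id first: generate_step may yield an
    # array scalar (with .item()) or a plain int, depending on version
    token_id = token if isinstance(token, int) else token.item()
    return tokenizer.decode(token_id)
```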
@@ -193,3 +210,59 @@ class ChatMLX(BaseChatModel):
             # break if stop sequence found
             if token == eos_token_id or (stop is not None and text in stop):
                 break
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+        *,
+        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Supports any tool definition handled by
+                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
+            tool_choice: Which tool to require the model to call.
+                Must be the name of the single provided function or
+                "auto" to automatically determine which function to call
+                (if any), or a dict of the form:
+                {"type": "function", "function": {"name": <<tool_name>>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        if tool_choice is not None and tool_choice:
+            if len(formatted_tools) != 1:
+                raise ValueError(
+                    "When specifying `tool_choice`, you must provide exactly one "
+                    f"tool. Received {len(formatted_tools)} tools."
+                )
+            if isinstance(tool_choice, str):
+                if tool_choice not in ("auto", "none"):
+                    tool_choice = {
+                        "type": "function",
+                        "function": {"name": tool_choice},
+                    }
+            elif isinstance(tool_choice, bool):
+                tool_choice = formatted_tools[0]
+            elif isinstance(tool_choice, dict):
+                if (
+                    formatted_tools[0]["function"]["name"]
+                    != tool_choice["function"]["name"]
+                ):
+                    raise ValueError(
+                        f"Tool choice {tool_choice} was specified, but the only "
+                        f"provided tool was {formatted_tools[0]['function']['name']}."
+                    )
+            else:
+                raise ValueError(
+                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
+                    f"Received: {tool_choice}"
+                )
+            kwargs["tool_choice"] = tool_choice
+        return super().bind(tools=formatted_tools, **kwargs)
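
As a usage note, a bare tool name passed as `tool_choice` is normalized into the OpenAI-style dict before binding. Reusing the hypothetical `multiply` tool from the sketch above:

```python
# forcing the single bound tool by name...
chat_forced = chat.bind_tools([multiply], tool_choice="multiply")

# ...is normalized to the equivalent explicit dict form
chat_forced = chat.bind_tools(
    [multiply],
    tool_choice={"type": "function", "function": {"name": "multiply"}},
)
```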