mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
Added bind_tools
support for ChatMLX
along with small fix in _stream
(#28743)
- **Description:** Added Support for `bind_tool` as requested in the issue. Plus two issue in `_stream` were fixed: - Corrected the Positional Argument Passing for `generate_step` - Accountability if `token` returned by `generate_step` is integer. - **Issue:** #28692
This commit is contained in:
parent
558b65ea32
commit
8d746086ab
@ -1,11 +1,23 @@
|
|||||||
"""MLX Chat Wrapper."""
|
"""MLX Chat Wrapper."""
|
||||||
|
|
||||||
from typing import Any, Iterator, List, Optional
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -20,6 +32,9 @@ from langchain_core.outputs import (
|
|||||||
ChatResult,
|
ChatResult,
|
||||||
LLMResult,
|
LLMResult,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
|
||||||
from langchain_community.llms.mlx_pipeline import MLXPipeline
|
from langchain_community.llms.mlx_pipeline import MLXPipeline
|
||||||
|
|
||||||
@ -94,7 +109,6 @@ class ChatMLX(BaseChatModel):
|
|||||||
raise ValueError("Last message must be a HumanMessage!")
|
raise ValueError("Last message must be a HumanMessage!")
|
||||||
|
|
||||||
messages_dicts = [self._to_chatml_format(m) for m in messages]
|
messages_dicts = [self._to_chatml_format(m) for m in messages]
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
messages_dicts,
|
messages_dicts,
|
||||||
tokenize=tokenize,
|
tokenize=tokenize,
|
||||||
@ -173,15 +187,18 @@ class ChatMLX(BaseChatModel):
|
|||||||
generate_step(
|
generate_step(
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
self.llm.model,
|
self.llm.model,
|
||||||
temp,
|
temp=temp,
|
||||||
repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
repetition_context_size,
|
repetition_context_size=repetition_context_size,
|
||||||
),
|
),
|
||||||
range(max_new_tokens),
|
range(max_new_tokens),
|
||||||
):
|
):
|
||||||
# identify text to yield
|
# identify text to yield
|
||||||
text: Optional[str] = None
|
text: Optional[str] = None
|
||||||
text = self.tokenizer.decode(token.item())
|
if not isinstance(token, int):
|
||||||
|
text = self.tokenizer.decode(token.item())
|
||||||
|
else:
|
||||||
|
text = self.tokenizer.decode(token)
|
||||||
|
|
||||||
# yield text, if any
|
# yield text, if any
|
||||||
if text:
|
if text:
|
||||||
@ -193,3 +210,59 @@ class ChatMLX(BaseChatModel):
|
|||||||
# break if stop sequence found
|
# break if stop sequence found
|
||||||
if token == eos_token_id or (stop is not None and text in stop):
|
if token == eos_token_id or (stop is not None and text in stop):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||||
|
*,
|
||||||
|
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
"""Bind tool-like objects to this chat model.
|
||||||
|
|
||||||
|
Assumes model is compatible with OpenAI tool-calling API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: A list of tool definitions to bind to this chat model.
|
||||||
|
Supports any tool definition handled by
|
||||||
|
:meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
|
||||||
|
tool_choice: Which tool to require the model to call.
|
||||||
|
Must be the name of the single provided function or
|
||||||
|
"auto" to automatically determine which function to call
|
||||||
|
(if any), or a dict of the form:
|
||||||
|
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||||
|
**kwargs: Any additional parameters to pass to the
|
||||||
|
:class:`~langchain.runnable.Runnable` constructor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||||
|
if tool_choice is not None and tool_choice:
|
||||||
|
if len(formatted_tools) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"When specifying `tool_choice`, you must provide exactly one "
|
||||||
|
f"tool. Received {len(formatted_tools)} tools."
|
||||||
|
)
|
||||||
|
if isinstance(tool_choice, str):
|
||||||
|
if tool_choice not in ("auto", "none"):
|
||||||
|
tool_choice = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": tool_choice},
|
||||||
|
}
|
||||||
|
elif isinstance(tool_choice, bool):
|
||||||
|
tool_choice = formatted_tools[0]
|
||||||
|
elif isinstance(tool_choice, dict):
|
||||||
|
if (
|
||||||
|
formatted_tools[0]["function"]["name"]
|
||||||
|
!= tool_choice["function"]["name"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool choice {tool_choice} was specified, but the only "
|
||||||
|
f"provided tool was {formatted_tools[0]['function']['name']}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized tool_choice type. Expected str, bool or dict. "
|
||||||
|
f"Received: {tool_choice}"
|
||||||
|
)
|
||||||
|
kwargs["tool_choice"] = tool_choice
|
||||||
|
return super().bind(tools=formatted_tools, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user